actr_runtime/
context.rs

1//! Runtime Context Implementation
2//!
3//! Implements the Context trait defined in actr-framework.
4
5use crate::inbound::{DataStreamRegistry, MediaFrameRegistry};
6use crate::outbound::OutGate;
7use crate::wire::webrtc::SignalingClient;
8use actr_framework::{Bytes, Context, DataStream, Dest, MediaSample};
9use actr_protocol::{
10    AIdCredential, ActorResult, ActrError, ActrId, ActrType, ProtocolError, RouteCandidatesRequest,
11    RpcEnvelope, RpcRequest, route_candidates_request,
12};
13use async_trait::async_trait;
14use futures_util::future::BoxFuture;
15use std::sync::Arc;
16
17/// RuntimeContext - Runtime's implementation of Context trait
18///
19/// # 设计特性
20///
21/// - **零虚函数**:内部使用 OutGate enum dispatch(非 dyn)
22/// - **智能路由**:根据 Dest 自动选择 InprocOut 或 OutprocOut
23/// - **完整实现**:包含 call/tell 的完整逻辑(编码、发送、解码)
24/// - **类型安全**:泛型方法提供编译时类型检查
25///
26/// # 性能
27///
28/// - OutGate 是 enum,使用静态分发
29/// - 编译器可完全内联整个调用链
30/// - 零虚函数调用开销
31#[derive(Clone)]
32pub struct RuntimeContext {
33    self_id: ActrId,
34    caller_id: Option<ActrId>,
35    trace_id: String,
36    request_id: String,
37    inproc_gate: OutGate,                          // Shell/Local 调用 - 立即可用
38    outproc_gate: Option<OutGate>,                 // 远程 Actor 调用 - 延迟初始化
39    data_stream_registry: Arc<DataStreamRegistry>, // DataStream 回调注册表
40    media_frame_registry: Arc<MediaFrameRegistry>, // MediaTrack 回调注册表
41    signaling_client: Arc<dyn SignalingClient>,
42    credential: AIdCredential,
43}
44
45impl RuntimeContext {
46    /// 创建新的 RuntimeContext
47    ///
48    /// # 参数
49    ///
50    /// - `self_id`: 当前 Actor 的 ID
51    /// - `caller_id`: 调用方 Actor ID(可选)
52    /// - `trace_id`: 分布式追踪 ID
53    /// - `request_id`: 当前请求唯一 ID
54    /// - `inproc_gate`: 进程内通信 gate(立即可用)
55    /// - `outproc_gate`: 跨进程通信 gate(可能为 None,等待 WebRTC 初始化)
56    /// - `data_stream_registry`: DataStream 回调注册表
57    /// - `media_frame_registry`: MediaTrack 回调注册表
58    /// - `signaling_client`: 用于路由发现的信令客户端
59    /// - `credential`: 该 Actor 的凭证(调用信令接口时使用)
60    #[allow(clippy::too_many_arguments)] // Internal API - all parameters are required
61    pub fn new(
62        self_id: ActrId,
63        caller_id: Option<ActrId>,
64        trace_id: String,
65        request_id: String,
66        inproc_gate: OutGate,
67        outproc_gate: Option<OutGate>,
68        data_stream_registry: Arc<DataStreamRegistry>,
69        media_frame_registry: Arc<MediaFrameRegistry>,
70        signaling_client: Arc<dyn SignalingClient>,
71        credential: AIdCredential,
72    ) -> Self {
73        Self {
74            self_id,
75            caller_id,
76            trace_id,
77            request_id,
78            inproc_gate,
79            outproc_gate,
80            data_stream_registry,
81            media_frame_registry,
82            signaling_client,
83            credential,
84        }
85    }
86
87    /// 根据 Dest 选择合适的 gate
88    ///
89    /// - Dest::Shell → inproc_gate(立即可用)
90    /// - Dest::Local → inproc_gate(立即可用)
91    /// - Dest::Actor(_) → outproc_gate(需要检查是否已初始化)
92    #[inline]
93    fn select_gate(&self, dest: &Dest) -> ActorResult<&OutGate> {
94        match dest {
95            Dest::Shell | Dest::Local => Ok(&self.inproc_gate),
96            Dest::Actor(_) => self.outproc_gate.as_ref().ok_or_else(|| {
97                ProtocolError::Actr(ActrError::GateNotInitialized {
98                    message: "OutprocOutGate not initialized yet (WebRTC setup in progress)"
99                        .to_string(),
100                })
101            }),
102        }
103    }
104
105    /// 从 Dest 提取目标 ActrId
106    ///
107    /// - Dest::Shell → self_id(Workload → App 反向调用)
108    /// - Dest::Local → self_id(调用本地 Workload)
109    /// - Dest::Actor(id) → id(远程调用)
110    #[inline]
111    fn extract_target_id<'a>(&'a self, dest: &'a Dest) -> &'a ActrId {
112        match dest {
113            Dest::Shell | Dest::Local => &self.self_id,
114            Dest::Actor(id) => id,
115        }
116    }
117}
118
119#[async_trait]
120impl Context for RuntimeContext {
121    // ========== 数据访问方法 ==========
122
123    fn self_id(&self) -> &ActrId {
124        &self.self_id
125    }
126
127    fn caller_id(&self) -> Option<&ActrId> {
128        self.caller_id.as_ref()
129    }
130
131    fn trace_id(&self) -> &str {
132        &self.trace_id
133    }
134
135    fn request_id(&self) -> &str {
136        &self.request_id
137    }
138
139    // ========== 通信能力方法 ==========
140
141    async fn call<R: RpcRequest>(&self, target: &Dest, request: R) -> ActorResult<R::Response> {
142        use actr_protocol::prost::Message as ProstMessage;
143
144        // 1. 编码请求为 protobuf bytes
145        let payload: Bytes = request.encode_to_vec().into();
146
147        // 2. 从 RpcRequest trait 获取 route_key(编译时确定)
148        let route_key = R::route_key().to_string();
149
150        // 3. 构造 RpcEnvelope(继承当前 Context 的追踪信息)
151        let envelope = RpcEnvelope {
152            route_key,
153            payload: Some(payload),
154            error: None,
155            trace_id: self.trace_id.clone(), // 继承 trace_id,保持调用链追踪
156            request_id: uuid::Uuid::new_v4().to_string(), // 生成新的 request_id
157            metadata: vec![],
158            timeout_ms: 30000, // 默认 30 秒超时
159        };
160
161        // 4. 根据 Dest 选择 gate 并提取目标 ActrId(Shell/Local 立即可用,Actor 需要检查)
162        let gate = self.select_gate(target)?;
163        let target_id = self.extract_target_id(target);
164
165        // 5. 通过 OutGate enum dispatch 发送(零虚函数调用!)
166        let response_bytes = gate.send_request(target_id, envelope).await?;
167
168        // 6. 解码响应(类型安全:R::Response)
169        R::Response::decode(&*response_bytes).map_err(|e| {
170            ProtocolError::Actr(ActrError::DecodeFailure {
171                message: format!(
172                    "Failed to decode {}: {}",
173                    std::any::type_name::<R::Response>(),
174                    e
175                ),
176            })
177        })
178    }
179
180    async fn tell<R: RpcRequest>(&self, target: &Dest, message: R) -> ActorResult<()> {
181        // 1. 编码消息
182        let payload: Bytes = message.encode_to_vec().into();
183
184        // 2. 获取 route_key
185        let route_key = R::route_key().to_string();
186
187        // 3. 构造 RpcEnvelope(fire-and-forget 语义)
188        let envelope = RpcEnvelope {
189            route_key,
190            payload: Some(payload),
191            error: None,
192            trace_id: self.trace_id.clone(),
193            request_id: uuid::Uuid::new_v4().to_string(),
194            metadata: vec![],
195            timeout_ms: 0, // 0 表示不等待响应
196        };
197
198        // 4. 根据 Dest 选择 gate 并提取目标 ActrId(Shell/Local 立即可用,Actor 需要检查)
199        let gate = self.select_gate(target)?;
200        let target_id = self.extract_target_id(target);
201
202        // 5. 通过 OutGate enum dispatch 发送
203        gate.send_message(target_id, envelope).await
204    }
205
206    // ========== Fast Path: DataStream Methods ==========
207
208    async fn register_stream<F>(&self, stream_id: String, callback: F) -> ActorResult<()>
209    where
210        F: Fn(DataStream, ActrId) -> BoxFuture<'static, ActorResult<()>> + Send + Sync + 'static,
211    {
212        tracing::debug!(
213            "📊 Registering DataStream callback for stream_id: {}",
214            stream_id
215        );
216        self.data_stream_registry
217            .register(stream_id, Arc::new(callback));
218        Ok(())
219    }
220
221    async fn unregister_stream(&self, stream_id: &str) -> ActorResult<()> {
222        tracing::debug!(
223            "🚫 Unregistering DataStream callback for stream_id: {}",
224            stream_id
225        );
226        self.data_stream_registry.unregister(stream_id);
227        Ok(())
228    }
229
230    async fn send_data_stream(&self, target: &Dest, chunk: DataStream) -> ActorResult<()> {
231        use actr_protocol::prost::Message as ProstMessage;
232
233        // 1. Serialize DataStream to bytes
234        let payload = chunk.encode_to_vec();
235
236        tracing::debug!(
237            "📤 Sending DataStream: stream_id={}, sequence={}, size={} bytes",
238            chunk.stream_id,
239            chunk.sequence,
240            payload.len()
241        );
242
243        // 2. Select gate based on Dest
244        let gate = self.select_gate(target)?;
245        let target_id = self.extract_target_id(target);
246
247        // 3. Send via OutGate with appropriate PayloadType
248        // Use StreamReliable for reliable ordered transmission
249        // TODO: Allow user to choose between StreamReliable and StreamLatencyFirst
250        gate.send_data_stream(
251            target_id,
252            actr_protocol::PayloadType::StreamReliable,
253            bytes::Bytes::from(payload),
254        )
255        .await
256    }
257
258    async fn discover_route_candidate(&self, target_type: &ActrType) -> ActorResult<ActrId> {
259        if !self.signaling_client.is_connected() {
260            return Err(ProtocolError::TransportError(
261                "Signaling client is not connected.".to_string(),
262            ));
263        }
264
265        let criteria = route_candidates_request::NodeSelectionCriteria {
266            candidate_count: 1,
267            ranking_factors: Vec::new(),
268            minimal_dependency_requirement: None,
269            minimal_health_requirement: None,
270        };
271
272        let request = RouteCandidatesRequest {
273            target_type: target_type.clone(),
274            criteria: Some(criteria),
275            client_location: None,
276        };
277
278        let response = self
279            .signaling_client
280            .send_route_candidates_request(self.self_id.clone(), self.credential.clone(), request)
281            .await
282            .map_err(|e| {
283                ProtocolError::TransportError(format!("Route candidates request failed: {e}"))
284            })?;
285
286        match response.result {
287            Some(actr_protocol::route_candidates_response::Result::Success(ok)) => {
288                ok.candidates.into_iter().next().ok_or_else(|| {
289                    ProtocolError::TargetNotFound(format!(
290                        "No route candidates for type {}.{}",
291                        target_type.manufacturer, target_type.name
292                    ))
293                })
294            }
295            Some(actr_protocol::route_candidates_response::Result::Error(err)) => {
296                Err(ProtocolError::TransportError(format!(
297                    "Route candidates error {}: {}",
298                    err.code, err.message
299                )))
300            }
301            None => Err(ProtocolError::TransportError(
302                "Route candidates response missing result".to_string(),
303            )),
304        }
305    }
306
307    // ========== Fast Path: MediaTrack Methods ==========
308
309    async fn register_media_track<F>(&self, track_id: String, callback: F) -> ActorResult<()>
310    where
311        F: Fn(MediaSample, ActrId) -> BoxFuture<'static, ActorResult<()>> + Send + Sync + 'static,
312    {
313        tracing::debug!(
314            "📹 Registering MediaTrack callback for track_id: {}",
315            track_id
316        );
317        self.media_frame_registry
318            .register(track_id, Arc::new(callback));
319        Ok(())
320    }
321
322    async fn unregister_media_track(&self, track_id: &str) -> ActorResult<()> {
323        tracing::debug!(
324            "📹 Unregistering MediaTrack callback for track_id: {}",
325            track_id
326        );
327        self.media_frame_registry.unregister(track_id);
328        Ok(())
329    }
330
331    async fn send_media_sample(
332        &self,
333        target: &Dest,
334        track_id: &str,
335        sample: MediaSample,
336    ) -> ActorResult<()> {
337        // 1. Select appropriate gate based on Dest
338        let gate = self.select_gate(target)?;
339
340        // 2. Extract target ActrId
341        let target_id = self.extract_target_id(target);
342
343        // 3. Send via OutGate (delegates to WebRTC Track)
344        gate.send_media_sample(target_id, track_id, sample).await
345    }
346}