Skip to main content

zagens_runtime/rlm/
bridge.rs

//! RPC bridge that services `llm_query` / `rlm_query` calls coming back
//! from the long-lived Python REPL during an RLM turn.
//!
//! This is the successor to the HTTP sidecar from earlier versions: instead
//! of binding a localhost port and routing requests through `urllib`, the
//! requests arrive over stdin/stdout and the bridge calls the LLM client
//! directly from Rust.
//!
//! The bridge tracks cumulative token usage and the recursion budget. For
//! `Rlm` / `RlmBatch` requests it recursively calls `run_rlm_turn_inner`
//! with the budget reduced by one, falling back to a plain one-shot
//! completion once the budget reaches zero. The resulting future-type cycle
//! (bridge → `run_rlm_turn_inner` → bridge) is broken by
//! `run_rlm_turn_inner` returning a boxed `dyn Future`.
13
14use std::sync::Arc;
15use std::time::Duration;
16use std::{future::Future, pin::Pin};
17
18use anyhow::Result;
19use futures_util::future::join_all;
20use tokio::sync::Mutex;
21
22use crate::llm_client::LlmClient;
23use crate::models::{ContentBlock, Message, MessageRequest, MessageResponse, SystemPrompt, Usage};
24use crate::repl::runtime::{BatchResp, RpcDispatcher, RpcRequest, RpcResponse, SingleResp};
25use crate::utils::spawn_supervised;
26
/// Seconds a single child completion may run before `llm_query` reports a
/// timeout (unchanged from the old sidecar default).
const CHILD_TIMEOUT_SECS: u64 = 120;
/// `max_tokens` applied to a child completion when the caller supplies none.
const DEFAULT_CHILD_MAX_TOKENS: u32 = 4_096;
/// Largest number of prompts accepted by one `LlmBatch` / `RlmBatch` RPC;
/// larger batches are rejected without being dispatched.
pub const MAX_BATCH: usize = 16;
33
34/// Object-safe slice of the LLM client interface that the RLM bridge needs.
35///
36/// `LlmClient` itself uses native async trait methods, which are not dyn-safe.
37/// The bridge only needs non-streaming completions, so this boxed-future shim
38/// gives tests a clean mock seam without changing the wider provider trait.
39pub(crate) trait RlmLlmClient: Send + Sync {
40    fn create_message_boxed(
41        &self,
42        request: MessageRequest,
43    ) -> Pin<Box<dyn Future<Output = Result<MessageResponse>> + Send + '_>>;
44}
45
46impl<T> RlmLlmClient for T
47where
48    T: LlmClient + Send + Sync + ?Sized,
49{
50    fn create_message_boxed(
51        &self,
52        request: MessageRequest,
53    ) -> Pin<Box<dyn Future<Output = Result<MessageResponse>> + Send + '_>> {
54        Box::pin(self.create_message(request))
55    }
56}
57
58/// State shared with the bridge across all RPC calls in one turn.
59pub struct RlmBridge {
60    client: Arc<dyn crate::llm_client::LlmClient>,
61    child_model: String,
62    /// Recursion budget remaining for `Rlm` / `RlmBatch` requests. When
63    /// zero, those requests fall back to plain `Llm` completions.
64    depth_remaining: u32,
65    usage: Arc<Mutex<Usage>>,
66}
67
68impl RlmBridge {
69    pub(crate) fn new(
70        client: Arc<dyn crate::llm_client::LlmClient>,
71        child_model: String,
72        depth_remaining: u32,
73    ) -> Self {
74        Self {
75            client,
76            child_model,
77            depth_remaining,
78            usage: Arc::new(Mutex::new(Usage::default())),
79        }
80    }
81
82    pub fn usage_handle(&self) -> Arc<Mutex<Usage>> {
83        Arc::clone(&self.usage)
84    }
85
86    async fn dispatch_llm(
87        &self,
88        prompt: String,
89        _model: Option<String>,
90        max_tokens: Option<u32>,
91        system: Option<String>,
92    ) -> SingleResp {
93        let request = MessageRequest {
94            // The Python helper accepts `model=` for older snippets, but it is
95            // intentionally not authoritative. RLM child calls are pinned to
96            // the tool's configured child model so model-generated Python
97            // cannot silently upgrade cheap fanout work to an expensive model.
98            model: self.child_model.clone(),
99            messages: vec![Message {
100                role: "user".to_string(),
101                content: vec![ContentBlock::Text {
102                    text: prompt,
103                    cache_control: None,
104                }],
105            }],
106            max_tokens: max_tokens.unwrap_or(DEFAULT_CHILD_MAX_TOKENS),
107            system: system.map(SystemPrompt::Text),
108            tools: None,
109            tool_choice: None,
110            metadata: None,
111            thinking: None,
112            reasoning_effort: None,
113            stream: Some(false),
114            temperature: Some(0.4_f32),
115            top_p: Some(0.9_f32),
116        };
117
118        let fut = self.client.create_message(request);
119        let response =
120            match tokio::time::timeout(Duration::from_secs(CHILD_TIMEOUT_SECS), fut).await {
121                Ok(Ok(r)) => r,
122                Ok(Err(e)) => {
123                    return SingleResp {
124                        text: String::new(),
125                        error: Some(format!("llm_query failed: {e}")),
126                    };
127                }
128                Err(_) => {
129                    return SingleResp {
130                        text: String::new(),
131                        error: Some(format!("llm_query timed out after {CHILD_TIMEOUT_SECS}s")),
132                    };
133                }
134            };
135
136        let text = response
137            .content
138            .iter()
139            .filter_map(|b| match b {
140                ContentBlock::Text { text, .. } => Some(text.as_str()),
141                _ => None,
142            })
143            .collect::<Vec<_>>()
144            .join("\n");
145
146        {
147            let mut u = self.usage.lock().await;
148            u.input_tokens = u.input_tokens.saturating_add(response.usage.input_tokens);
149            u.output_tokens = u.output_tokens.saturating_add(response.usage.output_tokens);
150        }
151
152        SingleResp { text, error: None }
153    }
154
155    async fn dispatch_llm_batch(&self, prompts: Vec<String>, _model: Option<String>) -> BatchResp {
156        if let Some(resp) = batch_guard(prompts.len()) {
157            return resp;
158        }
159
160        let model = Arc::new(self.child_model.clone());
161
162        let futures = prompts.into_iter().map(|prompt| {
163            let model = Arc::clone(&model);
164            async move {
165                self.dispatch_llm((*prompt).to_string(), Some((*model).clone()), None, None)
166                    .await
167            }
168        });
169
170        BatchResp {
171            results: join_all(futures).await,
172        }
173    }
174
175    async fn dispatch_rlm(&self, prompt: String, _model: Option<String>) -> SingleResp {
176        if self.depth_remaining == 0 {
177            // Budget exhausted — fall back to a one-shot child completion
178            // rather than returning an error. Matches the paper's behaviour
179            // ("sub_RLM gracefully degrades to llm_query at depth=0").
180            return self.dispatch_llm(prompt, None, None, None).await;
181        }
182
183        // Build a drain channel to absorb status events from the nested
184        // turn (we don't surface them; this dispatch is invisible to the
185        // outer agent stream).
186        let (tx, mut rx) = tokio::sync::mpsc::channel(64);
187        let drain = spawn_supervised(
188            "rlm-bridge-drain",
189            std::panic::Location::caller(),
190            async move { while rx.recv().await.is_some() {} },
191        );
192
193        let child_model = self.child_model.clone();
194
195        // Recursive call. The dyn-erasure on `run_rlm_turn_inner` breaks
196        // the `bridge → turn → bridge` opaque-future cycle.
197        let result = super::turn::run_rlm_turn_inner(
198            self.client.clone(),
199            child_model.clone(),
200            prompt,
201            None,
202            child_model,
203            tx,
204            self.depth_remaining.saturating_sub(1),
205        )
206        .await;
207
208        drain.abort();
209
210        {
211            let mut u = self.usage.lock().await;
212            u.input_tokens = u.input_tokens.saturating_add(result.usage.input_tokens);
213            u.output_tokens = u.output_tokens.saturating_add(result.usage.output_tokens);
214        }
215
216        SingleResp {
217            text: result.answer,
218            error: result.error,
219        }
220    }
221
222    async fn dispatch_rlm_batch(&self, prompts: Vec<String>, _model: Option<String>) -> BatchResp {
223        if let Some(resp) = batch_guard(prompts.len()) {
224            return resp;
225        }
226
227        let futures = prompts
228            .into_iter()
229            .map(|p| async move { self.dispatch_rlm(p, None).await });
230        BatchResp {
231            results: join_all(futures).await,
232        }
233    }
234}
235
236fn batch_guard(prompt_count: usize) -> Option<BatchResp> {
237    if prompt_count == 0 {
238        return Some(BatchResp { results: vec![] });
239    }
240    if prompt_count > MAX_BATCH {
241        return Some(BatchResp {
242            results: (0..prompt_count)
243                .map(|_| SingleResp {
244                    text: String::new(),
245                    error: Some(format!("batch too large: {prompt_count} > {MAX_BATCH}")),
246                })
247                .collect(),
248        });
249    }
250    None
251}
252
253impl RpcDispatcher for RlmBridge {
254    fn dispatch<'a>(
255        &'a self,
256        req: RpcRequest,
257    ) -> std::pin::Pin<Box<dyn std::future::Future<Output = RpcResponse> + Send + 'a>> {
258        Box::pin(async move {
259            match req {
260                RpcRequest::Llm {
261                    prompt,
262                    model,
263                    max_tokens,
264                    system,
265                } => {
266                    RpcResponse::Single(self.dispatch_llm(prompt, model, max_tokens, system).await)
267                }
268                RpcRequest::LlmBatch { prompts, model } => {
269                    RpcResponse::Batch(self.dispatch_llm_batch(prompts, model).await)
270                }
271                RpcRequest::Rlm { prompt, model } => {
272                    RpcResponse::Single(self.dispatch_rlm(prompt, model).await)
273                }
274                RpcRequest::RlmBatch { prompts, model } => {
275                    RpcResponse::Batch(self.dispatch_rlm_batch(prompts, model).await)
276                }
277            }
278        })
279    }
280}
281
#[cfg(test)]
mod tests {
    use super::*;
    use crate::llm_client::mock::MockLlmClient;

    /// Builds a canned assistant reply with one text block and the given
    /// token counts.
    fn mock_response(text: &str, input_tokens: u32, output_tokens: u32) -> MessageResponse {
        let usage = Usage {
            input_tokens,
            output_tokens,
            ..Usage::default()
        };
        let content = vec![ContentBlock::Text {
            text: text.to_owned(),
            cache_control: None,
        }];
        MessageResponse {
            id: "mock_msg".to_owned(),
            r#type: "message".to_owned(),
            role: "assistant".to_owned(),
            content,
            model: "mock-model".to_owned(),
            stop_reason: Some("end_turn".to_owned()),
            stop_sequence: None,
            container: None,
            usage,
        }
    }

    /// Wraps the mock in a bridge whose child model is `child-model`.
    fn bridge_for(mock: Arc<MockLlmClient>, depth_remaining: u32) -> RlmBridge {
        RlmBridge::new(mock, "child-model".to_owned(), depth_remaining)
    }

    #[test]
    fn batch_guard_allows_non_empty_batches_at_the_cap() {
        let verdict = batch_guard(MAX_BATCH);
        assert!(verdict.is_none());
    }

    #[test]
    fn batch_guard_returns_empty_response_for_empty_batches() {
        let Some(response) = batch_guard(0) else {
            panic!("empty batch should be handled");
        };
        assert!(response.results.is_empty());
    }

    #[test]
    fn batch_guard_returns_one_error_per_oversized_prompt() {
        let oversized = MAX_BATCH + 2;
        let response = batch_guard(oversized).expect("oversized batch should be handled");
        assert_eq!(response.results.len(), oversized);
        for result in &response.results {
            assert!(result.text.is_empty());
            assert!(result
                .error
                .as_deref()
                .is_some_and(|err| err.contains("batch too large")));
        }
    }

    #[tokio::test]
    async fn llm_dispatch_pins_configured_child_model() {
        let mock = Arc::new(MockLlmClient::new(Vec::new()));
        mock.push_message_response(mock_response("child answer", 7, 11));
        let bridge = bridge_for(Arc::clone(&mock), 1);

        let request = RpcRequest::Llm {
            prompt: "child prompt".to_owned(),
            model: Some("override-model".to_owned()),
            max_tokens: Some(123),
            system: Some("child system".to_owned()),
        };
        let single = match bridge.dispatch(request).await {
            RpcResponse::Single(single) => single,
            other => panic!("expected single response, got {other:?}"),
        };
        assert_eq!(single.text, "child answer");
        assert!(single.error.is_none());

        let captured = mock.captured_requests();
        assert_eq!(captured.len(), 1);
        let sent = &captured[0];
        assert_eq!(sent.model, "child-model");
        assert_eq!(sent.max_tokens, 123);
        assert_eq!(
            sent.system,
            Some(SystemPrompt::Text("child system".to_owned()))
        );

        let totals = bridge.usage.lock().await;
        assert_eq!((totals.input_tokens, totals.output_tokens), (7, 11));
    }

    #[tokio::test]
    async fn llm_batch_dispatch_pins_configured_child_model() {
        let mock = Arc::new(MockLlmClient::new(Vec::new()));
        for (text, input, output) in [("one", 1, 2), ("two", 3, 4), ("three", 5, 6)] {
            mock.push_message_response(mock_response(text, input, output));
        }
        let bridge = bridge_for(Arc::clone(&mock), 1);

        let request = RpcRequest::LlmBatch {
            prompts: ["a", "b", "c"].iter().map(|p| p.to_string()).collect(),
            model: Some("batch-model".to_owned()),
        };
        let batch = match bridge.dispatch(request).await {
            RpcResponse::Batch(batch) => batch,
            other => panic!("expected batch response, got {other:?}"),
        };
        let texts: Vec<&str> = batch.results.iter().map(|r| r.text.as_str()).collect();
        assert_eq!(texts, ["one", "two", "three"]);
        assert!(batch.results.iter().all(|r| r.error.is_none()));

        let captured = mock.captured_requests();
        assert_eq!(captured.len(), 3);
        for sent in captured.iter() {
            assert_eq!(sent.model, "child-model");
        }

        let totals = bridge.usage.lock().await;
        assert_eq!((totals.input_tokens, totals.output_tokens), (9, 12));
    }

    #[tokio::test]
    async fn rlm_dispatch_at_depth_zero_pins_configured_child_model() {
        let mock = Arc::new(MockLlmClient::new(Vec::new()));
        mock.push_message_response(mock_response("fallback answer", 3, 5));
        let bridge = bridge_for(Arc::clone(&mock), 0);

        let request = RpcRequest::Rlm {
            prompt: "nested prompt".to_owned(),
            model: Some("override-model".to_owned()),
        };
        let single = match bridge.dispatch(request).await {
            RpcResponse::Single(single) => single,
            other => panic!("expected single response, got {other:?}"),
        };
        assert_eq!(single.text, "fallback answer");
        assert!(single.error.is_none());

        {
            let totals = bridge.usage.lock().await;
            assert_eq!((totals.input_tokens, totals.output_tokens), (3, 5));
        }

        let captured = mock.captured_requests();
        assert_eq!(captured.len(), 1);
        assert_eq!(captured[0].model, "child-model");
    }
}