//! Foreign-implemented inference runner.
//!
//! Closes [Parslee-ai/car-releases#24]. The runner pattern lets a host
//! (Node.js, Python, Swift, Kotlin, or a remote JSON-RPC client) own
//! the wire format for cloud chat APIs while CAR stays in the
//! lifecycle path — observing every event, applying policy, recording
//! to the eventlog, supporting replay.
//!
//! When a model schema declares `source: ModelSource::Delegated { .. }`,
//! [`InferenceEngine::generate_tracked_stream`] checks the
//! process-wide runner slot ([`set_inference_runner`]); if a runner
//! is installed, the request is handed to it and its emitted events
//! flow through the stream's `Receiver<StreamEvent>` exactly as if a
//! native backend had produced them.
//!
//! The contract runs in both directions:
//! - Rust → host: `run(request)` is invoked with a fully-formed
//!   [`GenerateRequest`]. The trait is async so the host can do its
//!   own HTTP / SDK work. The host receives an event emitter it
//!   should call as chunks arrive.
//! - host → Rust: every chunk becomes a `StreamEvent` via the
//!   provided emitter. The final result is returned as the trait
//!   method's return value (text + tool_calls).
//!
//! Only one runner can be registered at a time (mirrors
//! [`car_multi::AgentRunner`]'s singleton constraint). Re-registering
//! overwrites the slot.
//!
//! [Parslee-ai/car-releases#24]: https://github.com/Parslee-ai/car-releases/issues/24

use std::sync::{Arc, OnceLock, RwLock};

use serde::{Deserialize, Serialize};

use crate::stream::StreamEvent;
use crate::tasks::generate::GenerateRequest;
use crate::InferenceError;

39/// Final result returned by the runner once the stream is complete.
40/// Mirrors what [`crate::StreamAccumulator::finish`] yields for native
41/// backends — runners are expected to aggregate as they emit so the
42/// returned text matches the concatenation of `TextDelta`s.
43#[derive(Debug, Clone, Serialize, Deserialize, Default)]
44pub struct RunnerResult {
45    pub text: String,
46    #[serde(default)]
47    pub tool_calls: Vec<crate::tasks::generate::ToolCall>,
48}
49
50/// Errors a runner can surface back to the engine. Kept as a string
51/// payload so foreign hosts (NAPI, PyO3, UniFFI) can populate it
52/// without sharing concrete error types.
53#[derive(Debug, Clone, thiserror::Error)]
54pub enum RunnerError {
55    /// The runner declined the request (e.g., the schema's `hint`
56    /// didn't match any provider it knows about).
57    #[error("runner declined: {0}")]
58    Declined(String),
59    /// The runner attempted the request but failed at the wire layer
60    /// (HTTP error, auth failure, provider rate limit, etc.).
61    #[error("runner failed: {0}")]
62    Failed(String),
63}
64
65impl From<RunnerError> for InferenceError {
66    fn from(value: RunnerError) -> Self {
67        InferenceError::InferenceFailed(value.to_string())
68    }
69}
70
71/// Sink the runner uses to emit stream events as they arrive from the
72/// upstream provider. The runner calls `emit(event)` for every chunk;
73/// CAR receives it through the [`StreamEvent`] channel returned from
74/// [`crate::InferenceEngine::generate_tracked_stream`].
75///
76/// Cloning shares the underlying channel — a runner that fans out
77/// (e.g., separate workers per tool call) can clone the emitter once
78/// per worker.
79#[derive(Clone)]
80pub struct EventEmitter {
81    tx: tokio::sync::mpsc::Sender<StreamEvent>,
82}
83
84impl EventEmitter {
85    pub(crate) fn new(tx: tokio::sync::mpsc::Sender<StreamEvent>) -> Self {
86        Self { tx }
87    }
88
89    /// Emit one event. Best-effort: if the receiver was dropped
90    /// (caller stopped listening), the event is silently discarded.
91    /// The runner can detect this via [`Self::is_closed`] and bail
92    /// early if it wants to stop the upstream call.
93    pub async fn emit(&self, event: StreamEvent) {
94        let _ = self.tx.send(event).await;
95    }
96
97    /// Has the receiver been dropped? Runners that pump events
98    /// indefinitely should poll this to stop on consumer cancellation.
99    pub fn is_closed(&self) -> bool {
100        self.tx.is_closed()
101    }
102}
103
104/// The trait foreign hosts implement to handle delegated inference.
105///
106/// Implementations are typically thin shims around a host-owned SDK
107/// (Vercel AI SDK on Node, Anthropic's Python client, etc.). They
108/// translate the supplied [`GenerateRequest`] to a wire-format call,
109/// stream chunks back through the [`EventEmitter`], and return a
110/// [`RunnerResult`] with the aggregated final text.
111#[async_trait::async_trait]
112pub trait InferenceRunner: Send + Sync {
113    /// Drive the request to completion. The runner MUST emit at
114    /// least a final [`StreamEvent::Done`] event before returning,
115    /// for parity with native backends — consumers of the resulting
116    /// stream rely on `Done` as the terminal marker.
117    async fn run(
118        &self,
119        request: GenerateRequest,
120        emitter: EventEmitter,
121    ) -> Result<RunnerResult, RunnerError>;
122}
123
124fn runner_slot() -> &'static RwLock<Option<Arc<dyn InferenceRunner>>> {
125    static SLOT: OnceLock<RwLock<Option<Arc<dyn InferenceRunner>>>> = OnceLock::new();
126    SLOT.get_or_init(|| RwLock::new(None))
127}
128
129/// Install a process-wide inference runner. Re-registering overwrites
130/// any previous runner. Pass `None` to clear the slot (rarely useful
131/// in production, but handy in tests).
132pub fn set_inference_runner(runner: Option<Arc<dyn InferenceRunner>>) {
133    let mut guard = runner_slot()
134        .write()
135        .expect("inference runner slot poisoned");
136    *guard = runner;
137}
138
139/// Snapshot the currently registered runner (if any). Cheap clone of
140/// the `Arc`; safe to call from any thread.
141pub fn current_inference_runner() -> Option<Arc<dyn InferenceRunner>> {
142    runner_slot()
143        .read()
144        .expect("inference runner slot poisoned")
145        .clone()
146}
147
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal runner: echoes the prompt back as one delta plus `Done`.
    struct EchoRunner;

    #[async_trait::async_trait]
    impl InferenceRunner for EchoRunner {
        async fn run(
            &self,
            request: GenerateRequest,
            emitter: EventEmitter,
        ) -> Result<RunnerResult, RunnerError> {
            let text = format!("echo:{}", request.prompt);
            emitter.emit(StreamEvent::TextDelta(text.clone())).await;
            emitter
                .emit(StreamEvent::Done {
                    text: text.clone(),
                    tool_calls: vec![],
                })
                .await;
            Ok(RunnerResult {
                text,
                tool_calls: vec![],
            })
        }
    }

    #[test]
    fn slot_round_trips() {
        // The OnceLock persists across tests in the same binary, so
        // set None first to guarantee a clean state for this test.
        set_inference_runner(None);
        assert!(current_inference_runner().is_none());
        set_inference_runner(Some(Arc::new(EchoRunner)));
        assert!(current_inference_runner().is_some());
        set_inference_runner(None);
        assert!(current_inference_runner().is_none());
    }

    #[tokio::test]
    async fn runner_can_emit_then_finish() {
        let runner: Arc<dyn InferenceRunner> = Arc::new(EchoRunner);
        let (tx, mut rx) = tokio::sync::mpsc::channel::<StreamEvent>(8);
        let emitter = EventEmitter::new(tx);
        let request = GenerateRequest {
            prompt: "hi".into(),
            ..Default::default()
        };
        let result = runner.run(request, emitter).await.unwrap();
        assert_eq!(result.text, "echo:hi");
        // Drain everything the runner emitted. The emitter (and with
        // it the only Sender) was dropped when `run` returned, so
        // `recv` yields None after the backlog is consumed.
        let mut got = Vec::new();
        while let Ok(evt) =
            tokio::time::timeout(std::time::Duration::from_millis(20), rx.recv()).await
        {
            match evt {
                Some(e) => got.push(e),
                None => break,
            }
        }
        assert_eq!(got.len(), 2);
        // Bug fix: the original discarded the `matches!` booleans, so
        // the variant checks never actually asserted anything.
        assert!(matches!(got[0], StreamEvent::TextDelta(_)));
        assert!(matches!(got[1], StreamEvent::Done { .. }));
    }
}