// sparrow/runtime/mod.rs

1use std::sync::Arc;
2use tokio::net::TcpListener;
3use tokio::sync::mpsc;
4use tokio_util::sync::CancellationToken;
5
6use crate::config::Config;
7use crate::engine::{Engine, Task};
8use crate::event::Event;
9use crate::memory::Memory;
10
11pub mod event_bus;
12pub mod ratelimit;
13pub mod recorder;
14pub mod scheduler;
15pub mod session;
16
17use event_bus::EventBus;
18use recorder::{FsRecorder, Recorder, RunInputs};
19use scheduler::{MemoryScheduler, Scheduler};
20
21// ─── Run request ────────────────────────────────────────────────────────────────
22
23#[derive(Debug, Clone)]
24pub struct RunRequest {
25    pub task: String,
26    pub agent: Option<String>,
27    pub autonomy: Option<crate::event::AutonomyLevel>,
28    pub budget: Option<f64>,
29}
30
31// ─── THE RUNTIME TRAIT ──────────────────────────────────────────────────────────
32
33#[async_trait::async_trait]
34pub trait Runtime: Send + Sync {
35    async fn submit(&self, req: RunRequest) -> anyhow::Result<String>;
36    fn subscribe_all(&self) -> tokio::sync::broadcast::Receiver<Event>;
37    async fn interrupt(&self, _run_id: &str, _msg: &str) -> anyhow::Result<()>;
38    async fn start(&self) -> anyhow::Result<()>;
39    async fn stop(&self) -> anyhow::Result<()>;
40    fn is_running(&self) -> bool;
41}
42
43// ─── Headless daemon implementation ─────────────────────────────────────────────
44
45pub struct SparrowRuntime {
46    engine: Arc<Engine>,
47    scheduler: Arc<MemoryScheduler>,
48    recorder: Arc<FsRecorder>,
49    event_bus: EventBus,
50    _memory: Arc<dyn Memory>,
51    config: Config,
52    running: std::sync::atomic::AtomicBool,
53    // Running tasks
54    active_runs: tokio::sync::Mutex<std::collections::HashMap<String, tokio::task::JoinHandle<()>>>,
55    cancellations: Arc<tokio::sync::Mutex<std::collections::HashMap<String, CancellationToken>>>,
56    /// Per-run injection senders. Populated on submit(), consumed by redirect().
57    injects:
58        Arc<tokio::sync::Mutex<std::collections::HashMap<String, mpsc::UnboundedSender<String>>>>,
59}
60
61impl SparrowRuntime {
62    pub fn new(
63        engine: Arc<Engine>,
64        scheduler: Arc<MemoryScheduler>,
65        recorder: Arc<FsRecorder>,
66        event_bus: EventBus,
67        memory: Arc<dyn Memory>,
68        config: Config,
69    ) -> Self {
70        Self {
71            engine,
72            scheduler,
73            recorder,
74            event_bus,
75            _memory: memory,
76            config,
77            running: std::sync::atomic::AtomicBool::new(false),
78            active_runs: tokio::sync::Mutex::new(std::collections::HashMap::new()),
79            cancellations: Arc::new(tokio::sync::Mutex::new(std::collections::HashMap::new())),
80            injects: Arc::new(tokio::sync::Mutex::new(std::collections::HashMap::new())),
81        }
82    }
83
84    /// Inject a user message into a running run mid-flight (§3.7).
85    /// Returns an error if no run with that id is active.
86    pub async fn redirect(&self, run_id: &str, msg: String) -> anyhow::Result<()> {
87        let injects = self.injects.lock().await;
88        match injects.get(run_id) {
89            Some(tx) => {
90                tx.send(msg)
91                    .map_err(|e| anyhow::anyhow!("inject channel closed: {}", e))?;
92                Ok(())
93            }
94            None => anyhow::bail!("No active run with id {}", run_id),
95        }
96    }
97
98    /// Spawn the cron tick loop
99    async fn cron_loop(&self) {
100        let scheduler = self.scheduler.clone();
101        let engine = self.engine.clone();
102        let event_bus = self.event_bus.clone();
103        let recorder = self.recorder.clone();
104
105        tokio::spawn(async move {
106            loop {
107                tokio::time::sleep(tokio::time::Duration::from_secs(30)).await;
108
109                let due_jobs = scheduler.tick().await;
110                for job in due_jobs {
111                    tracing::info!("Running scheduled job: {} ({})", job.id, job.task);
112
113                    let (tx, mut rx) = mpsc::unbounded_channel::<Event>();
114                    let task = Task {
115                        description: job.task.clone(),
116                        context: vec![],
117                    };
118
119                    let run_id = uuid::Uuid::new_v4().to_string();
120                    recorder.start_run(
121                        run_id.clone(),
122                        RunInputs {
123                            task: job.task.clone(),
124                            config_snapshot: serde_json::json!({}),
125                            model_id: "scheduled".into(),
126                            repo_head: None,
127                            timestamp: chrono::Utc::now().to_rfc3339(),
128                            agent: "scheduler".into(),
129                        },
130                    );
131
132                    let event_bus_clone = event_bus.clone();
133                    let recorder_clone = recorder.clone();
134                    let run_id_clone = run_id.clone();
135                    let engine_clone = engine.clone();
136
137                    tokio::spawn(async move {
138                        let engine_run_id = run_id_clone.clone();
139                        let engine_handle = tokio::spawn(async move {
140                            engine_clone
141                                .drive_with_run_id(task, tx, crate::event::RunId(engine_run_id))
142                                .await
143                        });
144
145                        while let Some(event) = rx.recv().await {
146                            recorder_clone.record(&event);
147                            event_bus_clone.publish(event);
148                        }
149
150                        if let Err(err) = engine_handle.await {
151                            tracing::error!("scheduled engine task failed: {}", err);
152                        }
153                        let _ = recorder_clone.finalize(&run_id_clone);
154                    });
155                }
156            }
157        });
158    }
159
160    /// Start a TCP socket for local API access
161    async fn serve_api(&self, addr: &str) -> anyhow::Result<()> {
162        let listener = TcpListener::bind(addr).await?;
163        tracing::info!("Runtime API listening on {}", addr);
164
165        let event_bus = self.event_bus.clone();
166
167        tokio::spawn(async move {
168            loop {
169                match listener.accept().await {
170                    Ok((mut stream, addr)) => {
171                        tracing::debug!("API connection from {}", addr);
172                        let mut rx = event_bus.subscribe_all();
173
174                        tokio::spawn(async move {
175                            use tokio::io::AsyncWriteExt;
176                            loop {
177                                match rx.recv().await {
178                                    Ok(event) => {
179                                        if !event.is_public() {
180                                            continue;
181                                        }
182                                        if let Ok(json) = serde_json::to_string(&event) {
183                                            let line = json + "\n";
184                                            if stream.write_all(line.as_bytes()).await.is_err() {
185                                                break;
186                                            }
187                                        }
188                                    }
189                                    Err(_) => break,
190                                }
191                            }
192                        });
193                    }
194                    Err(e) => {
195                        tracing::error!("Accept error: {}", e);
196                    }
197                }
198            }
199        });
200
201        Ok(())
202    }
203
204    /// Start a Unix domain socket for local API access (Linux/macOS).
205    /// Clients can `socat - UNIX-CONNECT:/path/to/socket` to receive NDJSON events.
206    #[cfg(unix)]
207    async fn serve_unix_socket(&self, path: &str) -> anyhow::Result<()> {
208        use tokio::net::UnixListener;
209
210        // Remove stale socket file from a previous run
211        let _ = std::fs::remove_file(path);
212
213        let listener = UnixListener::bind(path)?;
214        tracing::info!("Runtime Unix socket at {}", path);
215
216        // Restrict to owner only (rwx------) for security
217        #[cfg(target_os = "linux")]
218        {
219            use std::os::unix::fs::PermissionsExt;
220            let _ = std::fs::set_permissions(path, std::fs::Permissions::from_mode(0o600));
221        }
222
223        let event_bus = self.event_bus.clone();
224
225        tokio::spawn(async move {
226            loop {
227                match listener.accept().await {
228                    Ok((mut stream, _)) => {
229                        tracing::debug!("Unix socket connection");
230                        let mut rx = event_bus.subscribe_all();
231                        tokio::spawn(async move {
232                            use tokio::io::AsyncWriteExt;
233                            loop {
234                                match rx.recv().await {
235                                    Ok(event) => {
236                                        if !event.is_public() {
237                                            continue;
238                                        }
239                                        if let Ok(json) = serde_json::to_string(&event) {
240                                            let line = json + "\n";
241                                            if stream.write_all(line.as_bytes()).await.is_err() {
242                                                break;
243                                            }
244                                        }
245                                    }
246                                    Err(_) => break,
247                                }
248                            }
249                        });
250                    }
251                    Err(e) => {
252                        tracing::error!("Unix socket accept error: {}", e);
253                    }
254                }
255            }
256        });
257
258        Ok(())
259    }
260
261    /// No-op stub on Windows / non-unix targets.
262    #[cfg(not(unix))]
263    async fn serve_unix_socket(&self, _path: &str) -> anyhow::Result<()> {
264        tracing::debug!("Unix socket not available on this platform; skipping.");
265        Ok(())
266    }
267}
268
269#[async_trait::async_trait]
270impl Runtime for SparrowRuntime {
271    async fn submit(&self, req: RunRequest) -> anyhow::Result<String> {
272        let run_id = uuid::Uuid::new_v4().to_string();
273        let (tx, mut rx) = mpsc::unbounded_channel();
274        let cancel_token = CancellationToken::new();
275
276        let task = Task {
277            description: req.task.clone(),
278            context: vec![],
279        };
280
281        self.recorder.start_run(
282            run_id.clone(),
283            RunInputs {
284                task: req.task,
285                config_snapshot: serde_json::json!({}),
286                model_id: "runtime".into(),
287                repo_head: None,
288                timestamp: chrono::Utc::now().to_rfc3339(),
289                agent: req.agent.unwrap_or_else(|| "sparrow".into()),
290            },
291        );
292
293        let engine = self.engine.clone();
294        let event_bus = self.event_bus.clone();
295        let recorder = self.recorder.clone();
296        let rid = run_id.clone();
297        let token = cancel_token.clone();
298        let cancellations = self.cancellations.clone();
299
300        // Inject channel: redirect() pushes user messages into the running engine
301        let (inject_tx, inject_rx) = mpsc::unbounded_channel::<String>();
302        self.injects.lock().await.insert(run_id.clone(), inject_tx);
303        let injects_map = self.injects.clone();
304
305        // RAII cleanup so that a panic anywhere inside the spawned task still
306        // releases the run_id from `cancellations` / `injects` — preventing the
307        // map from growing unboundedly with zombie keys.
308        struct RunCleanup {
309            rid: String,
310            cancellations:
311                Arc<tokio::sync::Mutex<std::collections::HashMap<String, CancellationToken>>>,
312            injects: Arc<
313                tokio::sync::Mutex<
314                    std::collections::HashMap<String, mpsc::UnboundedSender<String>>,
315                >,
316            >,
317            recorder: Arc<FsRecorder>,
318        }
319        impl Drop for RunCleanup {
320            fn drop(&mut self) {
321                // We can't .await in drop; use blocking_lock when not on a Tokio
322                // worker thread, or spawn a task that owns the cleanup. The
323                // cheapest correct option is to spawn a detached task — runtime
324                // is guaranteed alive because submit() runs inside it.
325                let rid = std::mem::take(&mut self.rid);
326                let cancellations = self.cancellations.clone();
327                let injects = self.injects.clone();
328                let recorder = self.recorder.clone();
329                tokio::spawn(async move {
330                    cancellations.lock().await.remove(&rid);
331                    injects.lock().await.remove(&rid);
332                    let _ = recorder.finalize(&rid);
333                });
334            }
335        }
336
337        let handle = tokio::spawn(async move {
338            let _guard = RunCleanup {
339                rid: rid.clone(),
340                cancellations,
341                injects: injects_map,
342                recorder: recorder.clone(),
343            };
344            let engine_rid = rid.clone();
345            let cancel_rid = rid.clone();
346            let cancel_tx = tx.clone();
347            let engine_handle = tokio::spawn(async move {
348                tokio::select! {
349                    result = engine.drive_with_inject(
350                        task,
351                        tx,
352                        crate::event::RunId(engine_rid),
353                        Some(inject_rx),
354                    ) => result,
355                    _ = token.cancelled() => {
356                        let _ = cancel_tx.send(Event::Error {
357                            run: crate::event::RunId(cancel_rid),
358                            message: "interrupted".into(),
359                        });
360                        Ok(crate::event::OutcomeSummary {
361                            status: "interrupted".into(),
362                            cost_usd: 0.0,
363                            tokens: crate::event::TokenUsage {
364                                input: 0,
365                                output: 0,
366                            },
367                            diffs: vec![],
368                        })
369                    }
370                }
371            });
372
373            while let Some(event) = rx.recv().await {
374                recorder.record(&event);
375                event_bus.publish(event);
376            }
377
378            if let Err(err) = engine_handle.await {
379                tracing::error!("runtime engine task failed: {}", err);
380            }
381            // `_guard` drops here on the happy path and on early returns / panics.
382        });
383
384        self.cancellations
385            .lock()
386            .await
387            .insert(run_id.clone(), cancel_token);
388        self.active_runs.lock().await.insert(run_id.clone(), handle);
389        Ok(run_id)
390    }
391
392    fn subscribe_all(&self) -> tokio::sync::broadcast::Receiver<Event> {
393        self.event_bus.subscribe_all()
394    }
395
396    async fn interrupt(&self, run_id: &str, msg: &str) -> anyhow::Result<()> {
397        tracing::info!("Interrupt requested for run {}: {}", run_id, msg);
398        if let Some(token) = self.cancellations.lock().await.get(run_id).cloned() {
399            token.cancel();
400            Ok(())
401        } else {
402            anyhow::bail!("No active run found for interrupt: {}", run_id)
403        }
404    }
405
406    async fn start(&self) -> anyhow::Result<()> {
407        if self.running.load(std::sync::atomic::Ordering::SeqCst) {
408            anyhow::bail!("Runtime is already running");
409        }
410        self.running
411            .store(true, std::sync::atomic::Ordering::SeqCst);
412
413        let api_addr = "127.0.0.1:9337";
414        self.cron_loop().await;
415        self.serve_api(api_addr).await?;
416
417        // Unix socket: resolves to ~/.local/state/sparrow/sparrow.sock (or XDG equivalent)
418        let socket_path = self
419            .config
420            .state_dir
421            .join("sparrow.sock")
422            .to_string_lossy()
423            .to_string();
424        if let Err(e) = self.serve_unix_socket(&socket_path).await {
425            tracing::warn!("Unix socket failed (non-fatal): {}", e);
426        } else {
427            tracing::info!("Runtime Unix socket at {}", socket_path);
428        }
429
430        tracing::info!("Runtime started. TCP API at {}", api_addr);
431        tracing::info!("Scheduled jobs active.");
432
433        Ok(())
434    }
435
436    async fn stop(&self) -> anyhow::Result<()> {
437        self.running
438            .store(false, std::sync::atomic::Ordering::SeqCst);
439        tracing::info!("Runtime stopped.");
440        Ok(())
441    }
442
443    fn is_running(&self) -> bool {
444        self.running.load(std::sync::atomic::Ordering::SeqCst)
445    }
446}