Skip to main content

sparrow/runtime/
mod.rs

1use std::sync::Arc;
2use tokio::net::TcpListener;
3use tokio::sync::mpsc;
4use tokio_util::sync::CancellationToken;
5
6use crate::config::Config;
7use crate::engine::{Engine, Task};
8use crate::event::Event;
9use crate::memory::Memory;
10
11pub mod event_bus;
12pub mod ratelimit;
13pub mod recorder;
14pub mod scheduler;
15pub mod session;
16
17use event_bus::EventBus;
18use recorder::{FsRecorder, Recorder, RunInputs};
19use scheduler::{MemoryScheduler, Scheduler};
20
21// ─── Run request ────────────────────────────────────────────────────────────────
22
23#[derive(Debug, Clone)]
24pub struct RunRequest {
25    pub task: String,
26    pub agent: Option<String>,
27    pub autonomy: Option<crate::event::AutonomyLevel>,
28    pub budget: Option<f64>,
29}
30
31// ─── THE RUNTIME TRAIT ──────────────────────────────────────────────────────────
32
33#[async_trait::async_trait]
34pub trait Runtime: Send + Sync {
35    async fn submit(&self, req: RunRequest) -> anyhow::Result<String>;
36    fn subscribe_all(&self) -> tokio::sync::broadcast::Receiver<Event>;
37    async fn interrupt(&self, _run_id: &str, _msg: &str) -> anyhow::Result<()>;
38    async fn start(&self) -> anyhow::Result<()>;
39    async fn stop(&self) -> anyhow::Result<()>;
40    fn is_running(&self) -> bool;
41}
42
43// ─── Headless daemon implementation ─────────────────────────────────────────────
44
45pub struct SparrowRuntime {
46    engine: Arc<Engine>,
47    scheduler: Arc<MemoryScheduler>,
48    recorder: Arc<FsRecorder>,
49    event_bus: EventBus,
50    _memory: Arc<dyn Memory>,
51    config: Config,
52    running: std::sync::atomic::AtomicBool,
53    // Running tasks
54    active_runs: tokio::sync::Mutex<std::collections::HashMap<String, tokio::task::JoinHandle<()>>>,
55    cancellations: Arc<tokio::sync::Mutex<std::collections::HashMap<String, CancellationToken>>>,
56    /// Per-run injection senders. Populated on submit(), consumed by redirect().
57    injects:
58        Arc<tokio::sync::Mutex<std::collections::HashMap<String, mpsc::UnboundedSender<String>>>>,
59}
60
61impl SparrowRuntime {
62    pub fn new(
63        engine: Arc<Engine>,
64        scheduler: Arc<MemoryScheduler>,
65        recorder: Arc<FsRecorder>,
66        event_bus: EventBus,
67        memory: Arc<dyn Memory>,
68        config: Config,
69    ) -> Self {
70        Self {
71            engine,
72            scheduler,
73            recorder,
74            event_bus,
75            _memory: memory,
76            config,
77            running: std::sync::atomic::AtomicBool::new(false),
78            active_runs: tokio::sync::Mutex::new(std::collections::HashMap::new()),
79            cancellations: Arc::new(tokio::sync::Mutex::new(std::collections::HashMap::new())),
80            injects: Arc::new(tokio::sync::Mutex::new(std::collections::HashMap::new())),
81        }
82    }
83
84    /// Inject a user message into a running run mid-flight (§3.7).
85    /// Returns an error if no run with that id is active.
86    pub async fn redirect(&self, run_id: &str, msg: String) -> anyhow::Result<()> {
87        let injects = self.injects.lock().await;
88        match injects.get(run_id) {
89            Some(tx) => {
90                tx.send(msg)
91                    .map_err(|e| anyhow::anyhow!("inject channel closed: {}", e))?;
92                Ok(())
93            }
94            None => anyhow::bail!("No active run with id {}", run_id),
95        }
96    }
97
98    /// Spawn the cron tick loop
99    async fn cron_loop(&self) {
100        let scheduler = self.scheduler.clone();
101        let engine = self.engine.clone();
102        let event_bus = self.event_bus.clone();
103        let recorder = self.recorder.clone();
104
105        tokio::spawn(async move {
106            loop {
107                tokio::time::sleep(tokio::time::Duration::from_secs(30)).await;
108
109                let due_jobs = scheduler.tick().await;
110                for job in due_jobs {
111                    tracing::info!("Running scheduled job: {} ({})", job.id, job.task);
112
113                    let (tx, mut rx) = mpsc::unbounded_channel::<Event>();
114                    let task = Task {
115                        description: job.task.clone(),
116                        context: vec![],
117                    };
118
119                    let run_id = uuid::Uuid::new_v4().to_string();
120                    recorder.start_run(
121                        run_id.clone(),
122                        RunInputs {
123                            task: job.task.clone(),
124                            config_snapshot: serde_json::json!({}),
125                            model_id: "scheduled".into(),
126                            repo_head: None,
127                            timestamp: chrono::Utc::now().to_rfc3339(),
128                            agent: "scheduler".into(),
129                        },
130                    );
131
132                    let event_bus_clone = event_bus.clone();
133                    let recorder_clone = recorder.clone();
134                    let run_id_clone = run_id.clone();
135                    let engine_clone = engine.clone();
136
137                    tokio::spawn(async move {
138                        let engine_run_id = run_id_clone.clone();
139                        let engine_handle = tokio::spawn(async move {
140                            engine_clone
141                                .drive_with_run_id(task, tx, crate::event::RunId(engine_run_id))
142                                .await
143                        });
144
145                        while let Some(event) = rx.recv().await {
146                            recorder_clone.record(&event);
147                            event_bus_clone.publish(event);
148                        }
149
150                        if let Err(err) = engine_handle.await {
151                            tracing::error!("scheduled engine task failed: {}", err);
152                        }
153                        let _ = recorder_clone.finalize(&run_id_clone);
154                    });
155                }
156            }
157        });
158    }
159
160    /// Start a TCP socket for local API access
161    async fn serve_api(&self, addr: &str) -> anyhow::Result<()> {
162        let listener = TcpListener::bind(addr).await?;
163        tracing::info!("Runtime API listening on {}", addr);
164
165        let event_bus = self.event_bus.clone();
166
167        tokio::spawn(async move {
168            loop {
169                match listener.accept().await {
170                    Ok((mut stream, addr)) => {
171                        tracing::debug!("API connection from {}", addr);
172                        let mut rx = event_bus.subscribe_all();
173
174                        tokio::spawn(async move {
175                            use tokio::io::AsyncWriteExt;
176                            loop {
177                                match rx.recv().await {
178                                    Ok(event) => {
179                                        if !event.is_public() {
180                                            continue;
181                                        }
182                                        if let Ok(json) = serde_json::to_string(&event) {
183                                            let line = json + "\n";
184                                            if stream.write_all(line.as_bytes()).await.is_err() {
185                                                break;
186                                            }
187                                        }
188                                    }
189                                    Err(_) => break,
190                                }
191                            }
192                        });
193                    }
194                    Err(e) => {
195                        tracing::error!("Accept error: {}", e);
196                    }
197                }
198            }
199        });
200
201        Ok(())
202    }
203
204    /// Start a Unix domain socket for local API access (Linux/macOS).
205    /// Clients can `socat - UNIX-CONNECT:/path/to/socket` to receive NDJSON events.
206    #[cfg(unix)]
207    async fn serve_unix_socket(&self, path: &str) -> anyhow::Result<()> {
208        use tokio::net::UnixListener;
209
210        // Remove stale socket file from a previous run
211        let _ = std::fs::remove_file(path);
212
213        let listener = UnixListener::bind(path)?;
214        tracing::info!("Runtime Unix socket at {}", path);
215
216        // Restrict to owner only (rwx------) for security
217        #[cfg(target_os = "linux")]
218        {
219            use std::os::unix::fs::PermissionsExt;
220            let _ = std::fs::set_permissions(path, std::fs::Permissions::from_mode(0o600));
221        }
222
223        let event_bus = self.event_bus.clone();
224
225        tokio::spawn(async move {
226            loop {
227                match listener.accept().await {
228                    Ok((mut stream, _)) => {
229                        tracing::debug!("Unix socket connection");
230                        let mut rx = event_bus.subscribe_all();
231                        tokio::spawn(async move {
232                            use tokio::io::AsyncWriteExt;
233                            loop {
234                                match rx.recv().await {
235                                    Ok(event) => {
236                                        if !event.is_public() {
237                                            continue;
238                                        }
239                                        if let Ok(json) = serde_json::to_string(&event) {
240                                            let line = json + "\n";
241                                            if stream.write_all(line.as_bytes()).await.is_err() {
242                                                break;
243                                            }
244                                        }
245                                    }
246                                    Err(_) => break,
247                                }
248                            }
249                        });
250                    }
251                    Err(e) => {
252                        tracing::error!("Unix socket accept error: {}", e);
253                    }
254                }
255            }
256        });
257
258        Ok(())
259    }
260
261    /// No-op stub on Windows / non-unix targets.
262    #[cfg(not(unix))]
263    async fn serve_unix_socket(&self, _path: &str) -> anyhow::Result<()> {
264        tracing::debug!("Unix socket not available on this platform; skipping.");
265        Ok(())
266    }
267
268    /// Public accessor so surfaces can publish events (e.g. update
269    /// notifications). Lives on the concrete type, not on the `Runtime`
270    /// trait — adding it to the trait would force every Runtime impl to
271    /// expose its internal EventBus.
272    pub fn event_bus(&self) -> &EventBus {
273        &self.event_bus
274    }
275}
276
277#[async_trait::async_trait]
278impl Runtime for SparrowRuntime {
279    async fn submit(&self, req: RunRequest) -> anyhow::Result<String> {
280        let run_id = uuid::Uuid::new_v4().to_string();
281        let (tx, mut rx) = mpsc::unbounded_channel();
282        let cancel_token = CancellationToken::new();
283
284        let task = Task {
285            description: req.task.clone(),
286            context: vec![],
287        };
288
289        self.recorder.start_run(
290            run_id.clone(),
291            RunInputs {
292                task: req.task,
293                config_snapshot: serde_json::json!({}),
294                model_id: "runtime".into(),
295                repo_head: None,
296                timestamp: chrono::Utc::now().to_rfc3339(),
297                agent: req.agent.unwrap_or_else(|| "sparrow".into()),
298            },
299        );
300
301        let engine = self.engine.clone();
302        let event_bus = self.event_bus.clone();
303        let recorder = self.recorder.clone();
304        let rid = run_id.clone();
305        let token = cancel_token.clone();
306        let cancellations = self.cancellations.clone();
307
308        // Inject channel: redirect() pushes user messages into the running engine
309        let (inject_tx, inject_rx) = mpsc::unbounded_channel::<String>();
310        self.injects.lock().await.insert(run_id.clone(), inject_tx);
311        let injects_map = self.injects.clone();
312
313        // RAII cleanup so that a panic anywhere inside the spawned task still
314        // releases the run_id from `cancellations` / `injects` — preventing the
315        // map from growing unboundedly with zombie keys.
316        struct RunCleanup {
317            rid: String,
318            cancellations:
319                Arc<tokio::sync::Mutex<std::collections::HashMap<String, CancellationToken>>>,
320            injects: Arc<
321                tokio::sync::Mutex<
322                    std::collections::HashMap<String, mpsc::UnboundedSender<String>>,
323                >,
324            >,
325            recorder: Arc<FsRecorder>,
326        }
327        impl Drop for RunCleanup {
328            fn drop(&mut self) {
329                // We can't .await in drop; use blocking_lock when not on a Tokio
330                // worker thread, or spawn a task that owns the cleanup. The
331                // cheapest correct option is to spawn a detached task — runtime
332                // is guaranteed alive because submit() runs inside it.
333                let rid = std::mem::take(&mut self.rid);
334                let cancellations = self.cancellations.clone();
335                let injects = self.injects.clone();
336                let recorder = self.recorder.clone();
337                tokio::spawn(async move {
338                    cancellations.lock().await.remove(&rid);
339                    injects.lock().await.remove(&rid);
340                    let _ = recorder.finalize(&rid);
341                });
342            }
343        }
344
345        let handle = tokio::spawn(async move {
346            let _guard = RunCleanup {
347                rid: rid.clone(),
348                cancellations,
349                injects: injects_map,
350                recorder: recorder.clone(),
351            };
352            let engine_rid = rid.clone();
353            let cancel_rid = rid.clone();
354            let cancel_tx = tx.clone();
355            let engine_handle = tokio::spawn(async move {
356                tokio::select! {
357                    result = engine.drive_with_inject(
358                        task,
359                        tx,
360                        crate::event::RunId(engine_rid),
361                        Some(inject_rx),
362                    ) => result,
363                    _ = token.cancelled() => {
364                        let _ = cancel_tx.send(Event::Error {
365                            run: crate::event::RunId(cancel_rid),
366                            message: "interrupted".into(),
367                        });
368                        Ok(crate::event::OutcomeSummary {
369                            status: "interrupted".into(),
370                            cost_usd: 0.0,
371                            tokens: crate::event::TokenUsage {
372                                input: 0,
373                                output: 0,
374                            },
375                            diffs: vec![],
376                            cost_comparison: String::new(),
377                            duration_ms: None,
378                        })
379                    }
380                }
381            });
382
383            while let Some(event) = rx.recv().await {
384                recorder.record(&event);
385                event_bus.publish(event);
386            }
387
388            if let Err(err) = engine_handle.await {
389                tracing::error!("runtime engine task failed: {}", err);
390            }
391            // `_guard` drops here on the happy path and on early returns / panics.
392        });
393
394        self.cancellations
395            .lock()
396            .await
397            .insert(run_id.clone(), cancel_token);
398        self.active_runs.lock().await.insert(run_id.clone(), handle);
399        Ok(run_id)
400    }
401
402    fn subscribe_all(&self) -> tokio::sync::broadcast::Receiver<Event> {
403        self.event_bus.subscribe_all()
404    }
405
406    async fn interrupt(&self, run_id: &str, msg: &str) -> anyhow::Result<()> {
407        tracing::info!("Interrupt requested for run {}: {}", run_id, msg);
408        if let Some(token) = self.cancellations.lock().await.get(run_id).cloned() {
409            token.cancel();
410            Ok(())
411        } else {
412            anyhow::bail!("No active run found for interrupt: {}", run_id)
413        }
414    }
415
416    async fn start(&self) -> anyhow::Result<()> {
417        if self.running.load(std::sync::atomic::Ordering::SeqCst) {
418            anyhow::bail!("Runtime is already running");
419        }
420        self.running
421            .store(true, std::sync::atomic::Ordering::SeqCst);
422
423        let api_addr = "127.0.0.1:9337";
424        self.cron_loop().await;
425        self.serve_api(api_addr).await?;
426
427        // Unix socket: resolves to ~/.local/state/sparrow/sparrow.sock (or XDG equivalent)
428        let socket_path = self
429            .config
430            .state_dir
431            .join("sparrow.sock")
432            .to_string_lossy()
433            .to_string();
434        if let Err(e) = self.serve_unix_socket(&socket_path).await {
435            tracing::warn!("Unix socket failed (non-fatal): {}", e);
436        } else {
437            tracing::info!("Runtime Unix socket at {}", socket_path);
438        }
439
440        tracing::info!("Runtime started. TCP API at {}", api_addr);
441        tracing::info!("Scheduled jobs active.");
442
443        Ok(())
444    }
445
446    async fn stop(&self) -> anyhow::Result<()> {
447        self.running
448            .store(false, std::sync::atomic::Ordering::SeqCst);
449        tracing::info!("Runtime stopped.");
450        Ok(())
451    }
452
453    fn is_running(&self) -> bool {
454        self.running.load(std::sync::atomic::Ordering::SeqCst)
455    }
456}