Skip to main content

sparrow/
extras.rs

1use crate::engine::{Engine, Task};
2use crate::event::Event;
3use crate::memory::{Fact, Memory};
4use std::sync::Arc;
5use tokio::sync::mpsc;
6
7// ─── Auto-distillation ─────────────────────────────────────────────────────────
8
9/// After a successful run, extract durable facts about the user
10/// from the conversation trajectory.
11/// §3.8: "after sessions, distill durable facts/preferences into identity/facts_about_user"
12pub struct Distiller;
13
14impl Distiller {
15    /// Analyze run events and extract facts
16    pub async fn distill(memory: &Arc<dyn Memory>, events: &[Event], _task_description: &str) {
17        let mut facts = Vec::new();
18
19        // Extract user preferences from tools used
20        let mut lang_hints = Vec::new();
21        let mut framework_hints = Vec::new();
22        let mut style_hints = Vec::new();
23
24        for event in events {
25            match event {
26                Event::ToolUseProposed { args, .. } => {
27                    if let Some(path) = args.get("path").and_then(|v| v.as_str()) {
28                        if path.ends_with(".rs") {
29                            lang_hints.push("Rust".to_string());
30                        }
31                        if path.ends_with(".ts") || path.ends_with(".tsx") {
32                            lang_hints.push("TypeScript".to_string());
33                        }
34                        if path.ends_with(".py") {
35                            lang_hints.push("Python".to_string());
36                        }
37                        if path.ends_with(".go") {
38                            lang_hints.push("Go".to_string());
39                        }
40                        if path.ends_with(".js") || path.ends_with(".jsx") {
41                            lang_hints.push("JavaScript".to_string());
42                        }
43                    }
44                    if let Some(content) = args.get("content").and_then(|v| v.as_str()) {
45                        if content.contains("Cargo.toml") {
46                            framework_hints.push("Rust/Cargo".to_string());
47                        }
48                        if content.contains("package.json") {
49                            framework_hints.push("Node.js".to_string());
50                        }
51                        if content.contains("go.mod") {
52                            framework_hints.push("Go modules".to_string());
53                        }
54                    }
55                }
56                Event::ThinkingDelta { text, .. } => {
57                    if text.contains("refactor") {
58                        style_hints.push("prefers refactoring".to_string());
59                    }
60                    if text.contains("test") || text.contains("TDD") {
61                        style_hints.push("test-driven".to_string());
62                    }
63                }
64                _ => {}
65            }
66        }
67
68        // Deduplicate and save facts
69        lang_hints.sort();
70        lang_hints.dedup();
71        framework_hints.sort();
72        framework_hints.dedup();
73        style_hints.sort();
74        style_hints.dedup();
75
76        for lang in &lang_hints {
77            facts.push(Fact {
78                id: uuid::Uuid::new_v4().to_string(),
79                key: "user:language".into(),
80                value: lang.clone(),
81                created_at: chrono::Utc::now().format("%Y-%m-%d").to_string(),
82                updated_at: chrono::Utc::now().format("%Y-%m-%d").to_string(),
83            });
84        }
85        for fw in &framework_hints {
86            facts.push(Fact {
87                id: uuid::Uuid::new_v4().to_string(),
88                key: "user:framework".into(),
89                value: fw.clone(),
90                created_at: chrono::Utc::now().format("%Y-%m-%d").to_string(),
91                updated_at: chrono::Utc::now().format("%Y-%m-%d").to_string(),
92            });
93        }
94        for style in &style_hints {
95            facts.push(Fact {
96                id: uuid::Uuid::new_v4().to_string(),
97                key: "user:style".into(),
98                value: style.clone(),
99                created_at: chrono::Utc::now().format("%Y-%m-%d").to_string(),
100                updated_at: chrono::Utc::now().format("%Y-%m-%d").to_string(),
101            });
102        }
103
104        // Save deduplicated facts
105        let existing = memory.all_facts();
106        let existing_keys: Vec<&str> = existing.iter().map(|f| f.key.as_str()).collect();
107
108        for fact in facts {
109            if !existing_keys.contains(&fact.key.as_str()) {
110                let _ = memory.remember(fact);
111            }
112        }
113
114        if !lang_hints.is_empty() || !framework_hints.is_empty() {
115            tracing::info!(
116                "Distiller: extracted {} facts from session",
117                lang_hints.len() + framework_hints.len() + style_hints.len()
118            );
119        }
120    }
121}
122
123// ─── Lightweight deterministic embeddings ──────────────────────────────────────
124
125/// Lightweight semantic embeddings for repo memory.
126/// §3.8: "optional embeddings (per project)"
127#[derive(Debug, Clone)]
128pub struct Embeddings {
129    /// Stored text + normalized hashing-vector embedding.
130    pub vectors: Vec<(String, Vec<f64>)>,
131    dimensions: usize,
132}
133
134impl Embeddings {
135    pub const DEFAULT_DIMENSIONS: usize = 512;
136
137    pub fn new() -> Self {
138        Self {
139            vectors: Vec::new(),
140            dimensions: Self::DEFAULT_DIMENSIONS,
141        }
142    }
143
144    pub fn with_dimensions(dimensions: usize) -> Self {
145        Self {
146            vectors: Vec::new(),
147            dimensions: dimensions.max(16),
148        }
149    }
150
151    /// Build a deterministic hashing-vector embedding from text.
152    ///
153    /// This is intentionally local-first: no model/API key required, fixed
154    /// dimensions across documents, stable across sessions, and good enough for
155    /// lexical semantic recall in memory. It uses signed feature hashing with
156    /// unigram + adjacent bigram features, sublinear term frequency, and L2
157    /// normalization.
158    pub fn embed(&self, text: &str) -> Vec<f64> {
159        embed_with_dimensions(text, self.dimensions)
160    }
161
162    pub fn add(&mut self, text: &str) {
163        let clean = text.trim();
164        if clean.is_empty() {
165            return;
166        }
167        self.vectors.push((clean.to_string(), self.embed(clean)));
168    }
169
170    pub fn add_many<I, S>(&mut self, texts: I)
171    where
172        I: IntoIterator<Item = S>,
173        S: AsRef<str>,
174    {
175        for text in texts {
176            self.add(text.as_ref());
177        }
178    }
179
180    /// Find the most similar stored text to the query
181    pub fn search(&self, query: &str, k: usize) -> Vec<String> {
182        self.search_scored(query, k)
183            .into_iter()
184            .map(|(_, text)| text)
185            .collect()
186    }
187
188    pub fn search_scored(&self, query: &str, k: usize) -> Vec<(f64, String)> {
189        if k == 0 {
190            return Vec::new();
191        }
192        let q_embed = self.embed(query);
193        let mut scored: Vec<(f64, usize, &str)> = self
194            .vectors
195            .iter()
196            .enumerate()
197            .map(|(idx, (text, emb))| (cosine_sim(&q_embed, emb), idx, text.as_str()))
198            .collect();
199        scored.sort_by(|a, b| {
200            b.0.partial_cmp(&a.0)
201                .unwrap_or(std::cmp::Ordering::Equal)
202                .then(a.1.cmp(&b.1))
203        });
204        scored
205            .into_iter()
206            .take(k)
207            .filter(|(score, _, _)| *score > 0.0)
208            .map(|(score, _, text)| (score, text.to_string()))
209            .collect()
210    }
211
212    pub fn save_to_path(&self, path: impl AsRef<std::path::Path>) -> anyhow::Result<()> {
213        let snapshot = EmbeddingsSnapshot {
214            dimensions: self.dimensions,
215            texts: self.vectors.iter().map(|(text, _)| text.clone()).collect(),
216        };
217        let json = serde_json::to_string_pretty(&snapshot)?;
218        if let Some(parent) = path.as_ref().parent() {
219            std::fs::create_dir_all(parent)?;
220        }
221        std::fs::write(path, json)?;
222        Ok(())
223    }
224
225    pub fn load_from_path(path: impl AsRef<std::path::Path>) -> anyhow::Result<Self> {
226        let json = std::fs::read_to_string(path)?;
227        let snapshot: EmbeddingsSnapshot = serde_json::from_str(&json)?;
228        let mut index = Self::with_dimensions(snapshot.dimensions);
229        index.add_many(snapshot.texts);
230        Ok(index)
231    }
232}
233
234impl Default for Embeddings {
235    fn default() -> Self {
236        Self::new()
237    }
238}
239
240#[derive(serde::Serialize, serde::Deserialize)]
241struct EmbeddingsSnapshot {
242    dimensions: usize,
243    texts: Vec<String>,
244}
245
246fn embed_with_dimensions(text: &str, dimensions: usize) -> Vec<f64> {
247    let mut vector = vec![0.0; dimensions.max(16)];
248    let tokens = tokenize(text);
249    for token in &tokens {
250        add_feature(&mut vector, token, 1.0);
251    }
252    for pair in tokens.windows(2) {
253        add_feature(&mut vector, &format!("{}__{}", pair[0], pair[1]), 1.35);
254    }
255    for value in &mut vector {
256        if *value != 0.0 {
257            *value = value.signum() * value.abs().ln_1p();
258        }
259    }
260    normalize(&mut vector);
261    vector
262}
263
/// Split `text` into lowercase tokens.
///
/// A token is a maximal run of alphanumeric characters; every other character
/// is a separator. Lowercasing is applied character by character.
fn tokenize(text: &str) -> Vec<String> {
    text.split(|ch: char| !ch.is_alphanumeric())
        .filter(|run| !run.is_empty())
        // Per-char lowercasing on purpose: `str::to_lowercase` applies extra
        // context-sensitive rules and would not match char-wise output.
        .map(|run| run.chars().flat_map(char::to_lowercase).collect())
        .collect()
}
279
280fn add_feature(vector: &mut [f64], feature: &str, weight: f64) {
281    let hash = fnv1a64(feature.as_bytes());
282    let idx = (hash as usize) % vector.len();
283    let sign = if hash & (1 << 63) == 0 { 1.0 } else { -1.0 };
284    vector[idx] += sign * weight;
285}
286
/// 64-bit FNV-1a hash of `bytes`: XOR each byte in, then multiply by the prime
/// (wrapping).
fn fnv1a64(bytes: &[u8]) -> u64 {
    const OFFSET_BASIS: u64 = 0xcbf2_9ce4_8422_2325;
    const PRIME: u64 = 0x0100_0000_01b3;

    bytes.iter().fold(OFFSET_BASIS, |state, &byte| {
        (state ^ u64::from(byte)).wrapping_mul(PRIME)
    })
}
295
/// Scale `vector` in place to unit Euclidean length.
///
/// The vector is left unchanged unless its length is greater than zero.
fn normalize(vector: &mut [f64]) {
    let magnitude = vector
        .iter()
        .map(|component| component * component)
        .sum::<f64>()
        .sqrt();
    if magnitude > 0.0 {
        vector
            .iter_mut()
            .for_each(|component| *component /= magnitude);
    }
}
304
/// Cosine similarity of `a` and `b`, computed over their common prefix.
///
/// Returns 0.0 when either slice is empty or when either prefix has zero
/// length as a vector.
fn cosine_sim(a: &[f64], b: &[f64]) -> f64 {
    let shared = a.len().min(b.len());
    if shared == 0 {
        return 0.0;
    }
    let (lhs, rhs) = (&a[..shared], &b[..shared]);

    let magnitude = |v: &[f64]| v.iter().map(|x| x * x).sum::<f64>().sqrt();
    let dot: f64 = lhs.iter().zip(rhs).map(|(x, y)| x * y).sum();
    let (norm_a, norm_b) = (magnitude(lhs), magnitude(rhs));

    if norm_a == 0.0 || norm_b == 0.0 {
        return 0.0;
    }
    dot / (norm_a * norm_b)
}
319
320// ─── Replay re-execute ──────────────────────────────────────────────────────────
321
322/// Re-execute a transcript against a chosen model.
323/// §3.15: "can re-execute against a chosen model"
324pub struct ReExecuter {
325    engine: Arc<Engine>,
326}
327
328impl ReExecuter {
329    pub fn new(engine: Arc<Engine>) -> Self {
330        Self { engine }
331    }
332
333    /// Re-execute from a transcript: send the original task to the engine
334    /// with the same parameters.
335    pub async fn re_execute(
336        &self,
337        transcript: &crate::runtime::recorder::Transcript,
338    ) -> anyhow::Result<crate::event::OutcomeSummary> {
339        let (tx, _rx) = mpsc::unbounded_channel::<Event>();
340        let task = Task {
341            description: transcript.inputs.task.clone(),
342            context: vec![],
343        };
344        self.engine.drive(task, tx).await
345    }
346}
347
348// ─── OAuth flow ─────────────────────────────────────────────────────────────────
349
350pub struct OAuthFlow;
351
352impl OAuthFlow {
353    /// Start a device code OAuth flow.
354    /// Accepts the endpoints and scope from the provider registry — no hardcoded list.
355    pub async fn start_device_flow(
356        device_endpoint: &str,
357        token_endpoint_hint: &str, // unused here, kept for symmetry
358        client_id: &str,
359        scope: &str,
360    ) -> anyhow::Result<(String, String, String)> {
361        let _ = token_endpoint_hint;
362        let client = reqwest::Client::new();
363        let resp: serde_json::Value = client
364            .post(device_endpoint)
365            .form(&[("client_id", client_id), ("scope", scope)])
366            .send()
367            .await?
368            .json()
369            .await?;
370
371        let verification_uri = resp["verification_uri"]
372            .as_str()
373            .or_else(|| resp["verification_url"].as_str())
374            .unwrap_or("")
375            .to_string();
376        let user_code = resp["user_code"].as_str().unwrap_or("").to_string();
377        let device_code = resp["device_code"].as_str().unwrap_or("").to_string();
378
379        if device_code.is_empty() {
380            anyhow::bail!("Device flow start failed — provider response: {}", resp);
381        }
382
383        Ok((verification_uri, user_code, device_code))
384    }
385
386    /// Poll for token completion using the provider token endpoint.
387    pub async fn poll_token(
388        token_endpoint: &str,
389        client_id: &str,
390        device_code: &str,
391        timeout_secs: u64,
392    ) -> anyhow::Result<String> {
393        let client = reqwest::Client::new();
394        let start = std::time::Instant::now();
395
396        loop {
397            if start.elapsed().as_secs() > timeout_secs {
398                anyhow::bail!("OAuth timed out after {}s", timeout_secs);
399            }
400
401            let resp: serde_json::Value = client
402                .post(token_endpoint)
403                .form(&[
404                    ("client_id", client_id),
405                    ("device_code", device_code),
406                    ("grant_type", "urn:ietf:params:oauth:grant-type:device_code"),
407                ])
408                .send()
409                .await?
410                .json()
411                .await?;
412
413            if let Some(token) = resp["access_token"].as_str() {
414                return Ok(token.to_string());
415            }
416
417            match resp["error"].as_str() {
418                Some("authorization_pending") | Some("slow_down") => {
419                    tokio::time::sleep(tokio::time::Duration::from_secs(5)).await;
420                    continue;
421                }
422                Some(e) => anyhow::bail!("OAuth error: {}", e),
423                None => {
424                    tokio::time::sleep(tokio::time::Duration::from_secs(3)).await;
425                }
426            }
427        }
428    }
429}
430
431// ─── IBM Plex Mono reference ────────────────────────────────────────────────────
432
/// §9.3: "IBM Plex Mono everywhere (TUI authenticity + web)"
/// The font is not embedded in the binary; users install it system-wide.
/// This constant holds the download URL of the font archive.
pub const IBM_PLEX_MONO_URL: &str =
    "https://github.com/IBM/plex/releases/latest/download/IBM-Plex-Mono.zip";

/// Human-readable, per-platform instructions for installing IBM Plex Mono.
///
/// The returned text ends with a newline.
pub fn ibm_plex_install_instructions() -> String {
    const LINES: [&str; 8] = [
        "IBM Plex Mono — recommended font for Sparrow TUI.",
        "",
        "Install:",
        "  Linux:   sudo apt install fonts-ibm-plex",
        "  macOS:   brew install font-ibm-plex",
        "  Windows: Download from https://github.com/IBM/plex/releases",
        "",
        "Then update your terminal to use \"IBM Plex Mono\" as the font.",
    ];

    let mut text = LINES.join("\n");
    text.push('\n');
    text
}
451
452// ─── Chat mode ──────────────────────────────────────────────────────────────────
453
454/// Interactive multi-turn chat loop.
455/// §4: "sparrow chat — interactive multi-turn (TUI/inline)"
456pub struct ChatSession {
457    engine: Arc<Engine>,
458    history: Vec<crate::provider::Msg>,
459    running: bool,
460}
461
462impl ChatSession {
463    pub fn new(engine: Arc<Engine>) -> Self {
464        Self {
465            engine,
466            history: Vec::new(),
467            running: true,
468        }
469    }
470
471    pub async fn run_interactive(&mut self) -> anyhow::Result<()> {
472        use std::io::{self, Write};
473
474        println!("═══ Sparrow Chat ═══");
475        println!("Type your message and press Enter. Type /exit to quit.");
476        println!();
477
478        while self.running {
479            print!("◆ you › ");
480            io::stdout().flush()?;
481
482            let mut input = String::new();
483            io::stdin().read_line(&mut input)?;
484            let input = input.trim().to_string();
485
486            if input.is_empty() {
487                continue;
488            }
489            if input == "/exit" || input == "/quit" {
490                break;
491            }
492
493            self.history.push(crate::provider::Msg {
494                role: "user".into(),
495                content: vec![crate::provider::ContentBlock::Text {
496                    text: input.clone(),
497                }],
498            });
499
500            let (tx, mut rx) = mpsc::unbounded_channel::<Event>();
501            let task = Task {
502                description: input.clone(),
503                context: self.history.clone(),
504            };
505
506            let engine = self.engine.clone();
507            let handle = tokio::spawn(async move { engine.drive(task, tx).await });
508
509            while let Some(event) = rx.recv().await {
510                match &event {
511                    Event::ThinkingDelta { text, .. } => {
512                        print!("{}", text);
513                        io::stdout().flush()?;
514                    }
515                    Event::RunFinished { outcome, .. } => {
516                        println!("\n── {} | ${:.4} ──", outcome.status, outcome.cost_usd);
517                    }
518                    Event::Error { message, .. } => {
519                        eprintln!("\nError: {}", message);
520                    }
521                    _ => {}
522                }
523            }
524
525            match handle.await? {
526                Ok(outcome) => {
527                    self.history.push(crate::provider::Msg {
528                        role: "assistant".into(),
529                        content: vec![crate::provider::ContentBlock::Text {
530                            text: format!("[{}]", outcome.status),
531                        }],
532                    });
533                }
534                Err(e) => {
535                    eprintln!("Chat error: {}", e);
536                }
537            }
538            println!();
539        }
540
541        Ok(())
542    }
543}
544
545// ─── Configurable pipeline ──────────────────────────────────────────────────────
546
547/// Allow users to define custom swarm pipeline graphs.
548/// §3.11: "Configurable: number of agents, which model per role, pipeline graph."
549#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
550pub struct PipelineConfig {
551    pub name: String,
552    pub steps: Vec<PipelineStep>,
553    pub max_reworks: u32,
554}
555
556#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
557pub struct PipelineStep {
558    pub role: String,
559    pub model_preference: Option<String>,
560    pub prompt_override: Option<String>,
561    pub depends_on: Vec<String>,
562}
563
564impl PipelineConfig {
565    pub fn default_pipeline() -> Self {
566        Self {
567            name: "planner-coder-verifier".into(),
568            steps: vec![
569                PipelineStep {
570                    role: "planner".into(),
571                    model_preference: None,
572                    prompt_override: None,
573                    depends_on: vec![],
574                },
575                PipelineStep {
576                    role: "coder".into(),
577                    model_preference: None,
578                    prompt_override: None,
579                    depends_on: vec!["planner".into()],
580                },
581                PipelineStep {
582                    role: "verifier".into(),
583                    model_preference: None,
584                    prompt_override: None,
585                    depends_on: vec!["coder".into()],
586                },
587            ],
588            max_reworks: 3,
589        }
590    }
591
592    pub fn validate(&self) -> anyhow::Result<()> {
593        if self.steps.is_empty() {
594            anyhow::bail!("Pipeline must have at least one step");
595        }
596        for step in &self.steps {
597            for dep in &step.depends_on {
598                if !self.steps.iter().any(|s| s.role == *dep) {
599                    anyhow::bail!("Step '{}' depends on unknown role '{}'", step.role, dep);
600                }
601            }
602        }
603        Ok(())
604    }
605
606    pub fn from_toml(content: &str) -> anyhow::Result<Self> {
607        Ok(toml::from_str(content)?)
608    }
609
610    pub fn to_toml(&self) -> anyhow::Result<String> {
611        Ok(toml::to_string_pretty(self)?)
612    }
613}
614
615// ─── Profile isolation ──────────────────────────────────────────────────────────
616
617/// Full profile isolation: separate config, memory, agents per profile.
618/// §4: "sparrow profile <create|list|use> — multi-instance profiles"
619pub struct Profile {
620    pub name: String,
621    pub config_dir: std::path::PathBuf,
622    pub state_dir: std::path::PathBuf,
623    pub config: crate::config::Config,
624    pub memory: Arc<dyn Memory>,
625}
626
627impl Profile {
628    pub fn load(name: &str) -> anyhow::Result<Self> {
629        let base_config = dirs::config_dir().unwrap_or_default().join("sparrow");
630        let base_state = dirs::state_dir().unwrap_or_default().join("sparrow");
631
632        let config_dir = base_config.join("profiles").join(name);
633        let state_dir = base_state.join("profiles").join(name);
634
635        std::fs::create_dir_all(&config_dir)?;
636        std::fs::create_dir_all(&state_dir)?;
637
638        let config = if config_dir.join("config.toml").exists() {
639            let content = std::fs::read_to_string(config_dir.join("config.toml"))?;
640            toml::from_str(&content)?
641        } else {
642            // Inherit from default config if available
643            let default = base_config.join("config.toml");
644            if default.exists() {
645                let content = std::fs::read_to_string(&default)?;
646                toml::from_str(&content)?
647            } else {
648                crate::config::Config {
649                    defaults: Default::default(),
650                    routing: Default::default(),
651                    budget: Default::default(),
652                    providers: Default::default(),
653                    surfaces: Default::default(),
654                    skills: Default::default(),
655                    permissions: Default::default(),
656                    hooks: Default::default(),
657                    theme: "captain".into(),
658                    config_dir: config_dir.clone(),
659                    state_dir: state_dir.clone(),
660                    forced_model: None,
661                }
662            }
663        };
664
665        let memory: Arc<dyn Memory> = Arc::new(crate::memory::SqliteMemory::open(
666            &state_dir.join("profile.db"),
667        )?);
668
669        Ok(Self {
670            name: name.to_string(),
671            config_dir,
672            state_dir,
673            config,
674            memory,
675        })
676    }
677
678    pub fn create(name: &str) -> anyhow::Result<()> {
679        let base_config = dirs::config_dir().unwrap_or_default().join("sparrow");
680        let config_dir = base_config.join("profiles").join(name);
681        std::fs::create_dir_all(&config_dir)?;
682
683        // Copy default config
684        let default = base_config.join("config.toml");
685        if default.exists() {
686            std::fs::copy(&default, config_dir.join("config.toml"))?;
687        }
688
689        let base_state = dirs::state_dir().unwrap_or_default().join("sparrow");
690        std::fs::create_dir_all(base_state.join("profiles").join(name))?;
691
692        Ok(())
693    }
694
695    pub fn list() -> Vec<String> {
696        let base_config = dirs::config_dir().unwrap_or_default().join("sparrow");
697        let profiles_dir = base_config.join("profiles");
698        let mut names = Vec::new();
699        if let Ok(entries) = std::fs::read_dir(&profiles_dir) {
700            for entry in entries.flatten() {
701                if entry.path().is_dir() {
702                    if let Some(name) = entry.file_name().to_str() {
703                        names.push(name.to_string());
704                    }
705                }
706            }
707        }
708        names.sort();
709        names
710    }
711}