sparrow-cli 0.5.1

A local-first Rust agent cockpit — route, run, replay, rewind
Documentation
//! Multimodal tools: image generation, text-to-speech, and speech-to-text
//! transcription (§15).
//!
//! All three target OpenAI-compatible endpoints, so any provider that exposes
//! `/images/generations`, `/audio/speech`, or `/audio/transcriptions` works.
//! The user supplies the API key via an environment variable, resolved at call
//! time. No fake success: a missing key or a non-2xx response returns a real
//! error.

use async_trait::async_trait;
use serde_json::json;

use super::{Tool, ToolCtx, ToolResult};
use crate::event::{Block, RiskLevel};

/// Return the first non-blank value among the given environment variables.
///
/// Variables are consulted in order; unset, non-UTF-8, or whitespace-only
/// values are skipped. The returned key is trimmed so stray whitespace (for
/// example a trailing newline pasted into a `.env` file) never ends up in the
/// `Authorization` header.
fn resolve_key(env_names: &[&str]) -> Option<String> {
    env_names.iter().find_map(|name| {
        let value = std::env::var(name).ok()?;
        let trimmed = value.trim();
        (!trimmed.is_empty()).then(|| trimmed.to_string())
    })
}

// ─── Image generation ─────────────────────────────────────────────────────────

/// Generate an image from a prompt via an OpenAI-compatible images endpoint.
///
/// Configuration is read once, at construction time:
/// - `IMAGE_API_BASE` — API base URL (default `https://api.openai.com/v1`)
/// - `IMAGE_MODEL` — model name (default `gpt-image-1`)
///
/// Unset or whitespace-only variables fall back to the default; set values
/// are trimmed.
pub struct ImageGen {
    base_url: String,
    model: String,
}

impl ImageGen {
    /// Build the tool from the `IMAGE_API_BASE` / `IMAGE_MODEL` environment.
    pub fn new() -> Self {
        // Blank values count as unset so `IMAGE_API_BASE=` in a `.env` file
        // does not yield a bare `/images/generations` endpoint.
        let env_or = |name: &str, default: &str| {
            std::env::var(name)
                .ok()
                .map(|v| v.trim().to_string())
                .filter(|v| !v.is_empty())
                .unwrap_or_else(|| default.to_string())
        };
        Self {
            base_url: env_or("IMAGE_API_BASE", "https://api.openai.com/v1"),
            model: env_or("IMAGE_MODEL", "gpt-image-1"),
        }
    }
}

impl Default for ImageGen {
    fn default() -> Self {
        Self::new()
    }
}

#[async_trait]
impl Tool for ImageGen {
    fn name(&self) -> &str {
        "image_generate"
    }
    fn description(&self) -> &str {
        "Generate an image from a text prompt. Saves a PNG into the workspace and returns its path."
    }
    fn schema(&self) -> serde_json::Value {
        json!({
            "type": "object",
            "properties": {
                "prompt": { "type": "string", "description": "Image description" },
                "filename": { "type": "string", "description": "Output filename (default: generated.png)" },
                "size": { "type": "string", "description": "e.g. 1024x1024" }
            },
            "required": ["prompt"]
        })
    }
    fn risk(&self) -> RiskLevel {
        RiskLevel::Network
    }
    async fn call(&self, args: serde_json::Value, ctx: &ToolCtx) -> anyhow::Result<ToolResult> {
        let Some(key) = resolve_key(&["IMAGE_API_KEY", "OPENAI_API_KEY"]) else {
            return Ok(ToolResult::error(
                "No image API key. Set IMAGE_API_KEY or OPENAI_API_KEY.",
            ));
        };
        let prompt = args["prompt"].as_str().unwrap_or("");
        let size = args["size"].as_str().unwrap_or("1024x1024");
        let filename = args["filename"].as_str().unwrap_or("generated.png");

        let endpoint = format!("{}/images/generations", self.base_url.trim_end_matches('/'));
        if let Err(why) = crate::tools::search_and_web::validate_public_url(&endpoint) {
            return Ok(ToolResult::error(format!(
                "Refused IMAGE_API_BASE ({}): {}",
                why, endpoint
            )));
        }
        let client = reqwest::Client::new();
        let resp = client
            .post(&endpoint)
            .bearer_auth(&key)
            .json(&json!({
                "model": self.model,
                "prompt": prompt,
                "size": size,
                "n": 1,
                "response_format": "b64_json"
            }))
            .send()
            .await?;
        if !resp.status().is_success() {
            let status = resp.status();
            let body = resp.text().await.unwrap_or_default();
            return Ok(ToolResult::error(format!(
                "image API error {}: {}",
                status, body
            )));
        }
        let value: serde_json::Value = resp.json().await?;
        let b64 = value["data"][0]["b64_json"].as_str();
        let url = value["data"][0]["url"].as_str();

        if let Some(b64) = b64 {
            let bytes = base64_decode::decode(b64)
                .map_err(|e| anyhow::anyhow!("invalid base64 image: {}", e))?;
            let path = super::resolve_workspace_path(&ctx.workspace_root, filename)?;
            std::fs::write(&path, &bytes)?;
            Ok(ToolResult::ok(vec![Block::Text(format!(
                "image saved to {} ({} bytes)",
                path.display(),
                bytes.len()
            ))]))
        } else if let Some(url) = url {
            Ok(ToolResult::ok(vec![Block::Text(format!(
                "image generated: {}",
                url
            ))]))
        } else {
            Ok(ToolResult::error("image API returned no data"))
        }
    }
}

// ─── Text to speech ─────────────────────────────────────────────────────────

/// Synthesize speech from text via an OpenAI-compatible `/audio/speech` endpoint.
///
/// Configuration is read once, at construction time:
/// - `TTS_API_BASE` — API base URL (default `https://api.openai.com/v1`)
/// - `TTS_MODEL` — model name (default `gpt-4o-mini-tts`)
///
/// Unset or whitespace-only variables fall back to the default; set values
/// are trimmed.
pub struct Tts {
    base_url: String,
    model: String,
}

impl Tts {
    /// Build the tool from the `TTS_API_BASE` / `TTS_MODEL` environment.
    pub fn new() -> Self {
        // Blank values count as unset so `TTS_API_BASE=` in a `.env` file
        // does not yield a bare `/audio/speech` endpoint.
        let env_or = |name: &str, default: &str| {
            std::env::var(name)
                .ok()
                .map(|v| v.trim().to_string())
                .filter(|v| !v.is_empty())
                .unwrap_or_else(|| default.to_string())
        };
        Self {
            base_url: env_or("TTS_API_BASE", "https://api.openai.com/v1"),
            model: env_or("TTS_MODEL", "gpt-4o-mini-tts"),
        }
    }
}

impl Default for Tts {
    fn default() -> Self {
        Self::new()
    }
}

#[async_trait]
impl Tool for Tts {
    fn name(&self) -> &str {
        "text_to_speech"
    }
    fn description(&self) -> &str {
        "Synthesize speech from text via an OpenAI-compatible /audio/speech endpoint. Saves an audio file into the workspace."
    }
    fn schema(&self) -> serde_json::Value {
        json!({
            "type": "object",
            "properties": {
                "text": { "type": "string" },
                "voice": { "type": "string", "description": "e.g. alloy" },
                "filename": { "type": "string", "description": "default: speech.mp3" }
            },
            "required": ["text"]
        })
    }
    fn risk(&self) -> RiskLevel {
        RiskLevel::Network
    }
    async fn call(&self, args: serde_json::Value, ctx: &ToolCtx) -> anyhow::Result<ToolResult> {
        let Some(key) = resolve_key(&["TTS_API_KEY", "OPENAI_API_KEY"]) else {
            return Ok(ToolResult::error(
                "No TTS API key. Set TTS_API_KEY or OPENAI_API_KEY.",
            ));
        };
        let text = args["text"].as_str().unwrap_or("");
        let voice = args["voice"].as_str().unwrap_or("alloy");
        let filename = args["filename"].as_str().unwrap_or("speech.mp3");

        let endpoint = format!("{}/audio/speech", self.base_url.trim_end_matches('/'));
        if let Err(why) = crate::tools::search_and_web::validate_public_url(&endpoint) {
            return Ok(ToolResult::error(format!(
                "Refused TTS_API_BASE ({}): {}",
                why, endpoint
            )));
        }
        let client = reqwest::Client::new();
        let resp = client
            .post(&endpoint)
            .bearer_auth(&key)
            .json(&json!({ "model": self.model, "input": text, "voice": voice }))
            .send()
            .await?;
        if !resp.status().is_success() {
            let status = resp.status();
            let body = resp.text().await.unwrap_or_default();
            return Ok(ToolResult::error(format!(
                "tts API error {}: {}",
                status, body
            )));
        }
        let bytes = resp.bytes().await?;
        let path = super::resolve_workspace_path(&ctx.workspace_root, filename)?;
        std::fs::write(&path, &bytes)?;
        Ok(ToolResult::ok(vec![Block::Text(format!(
            "audio saved to {} ({} bytes)",
            path.display(),
            bytes.len()
        ))]))
    }
}

// ─── Speech to text (Transcribe) ────────────────────────────────────────────────
//
// Voice-mode building block: posts a workspace audio file to an OpenAI-compatible
// `/audio/transcriptions` endpoint and returns the transcript as text. Missing
// key or non-2xx response is an HONEST error — never a fake success.

/// Speech-to-text tool backed by an OpenAI-compatible `/audio/transcriptions`
/// endpoint.
///
/// Configuration is read once, at construction time:
/// - `TRANSCRIBE_API_BASE` — API base URL (default `https://api.openai.com/v1`)
/// - `TRANSCRIBE_MODEL` — model name (default `whisper-1`)
///
/// Unset or whitespace-only variables fall back to the default; set values
/// are trimmed.
pub struct Transcribe {
    base_url: String,
    model: String,
}

impl Transcribe {
    /// Build the tool from the `TRANSCRIBE_API_BASE` / `TRANSCRIBE_MODEL`
    /// environment.
    pub fn new() -> Self {
        // Blank values count as unset so `TRANSCRIBE_API_BASE=` in a `.env`
        // file does not yield a bare `/audio/transcriptions` endpoint.
        let env_or = |name: &str, default: &str| {
            std::env::var(name)
                .ok()
                .map(|v| v.trim().to_string())
                .filter(|v| !v.is_empty())
                .unwrap_or_else(|| default.to_string())
        };
        Self {
            base_url: env_or("TRANSCRIBE_API_BASE", "https://api.openai.com/v1"),
            model: env_or("TRANSCRIBE_MODEL", "whisper-1"),
        }
    }
}

impl Default for Transcribe {
    fn default() -> Self {
        Self::new()
    }
}

#[async_trait]
impl Tool for Transcribe {
    fn name(&self) -> &str {
        "transcribe"
    }
    fn description(&self) -> &str {
        "Transcribe an audio file in the workspace to text via an OpenAI-compatible /audio/transcriptions endpoint."
    }
    fn schema(&self) -> serde_json::Value {
        json!({
            "type": "object",
            "properties": {
                "path": { "type": "string", "description": "Workspace-relative path to the audio file" },
                "language": { "type": "string", "description": "Optional ISO-639-1 language hint" }
            },
            "required": ["path"]
        })
    }
    fn risk(&self) -> RiskLevel {
        RiskLevel::Network
    }
    async fn call(&self, args: serde_json::Value, ctx: &ToolCtx) -> anyhow::Result<ToolResult> {
        let Some(key) = resolve_key(&["TRANSCRIBE_API_KEY", "OPENAI_API_KEY"]) else {
            return Ok(ToolResult::error(
                "No transcription API key. Set TRANSCRIBE_API_KEY or OPENAI_API_KEY.",
            ));
        };
        let path = args["path"].as_str().unwrap_or("");
        if path.is_empty() {
            return Ok(ToolResult::error("transcribe: missing 'path' argument"));
        }
        let full = super::resolve_workspace_path(&ctx.workspace_root, path)?;
        if !full.exists() {
            return Ok(ToolResult::error(format!("audio file not found: {}", path)));
        }
        let bytes = std::fs::read(&full)?;
        let filename = full
            .file_name()
            .map(|s| s.to_string_lossy().to_string())
            .unwrap_or_else(|| "audio.bin".into());
        let mime = mime_guess::from_path(&full)
            .first_or_octet_stream()
            .to_string();

        let part = reqwest::multipart::Part::bytes(bytes)
            .file_name(filename)
            .mime_str(&mime)
            .unwrap_or_else(|_| reqwest::multipart::Part::text("")); // mime parse rarely fails
        let mut form = reqwest::multipart::Form::new()
            .text("model", self.model.clone())
            .part("file", part);
        if let Some(lang) = args["language"].as_str() {
            if !lang.is_empty() {
                form = form.text("language", lang.to_string());
            }
        }

        let endpoint = format!(
            "{}/audio/transcriptions",
            self.base_url.trim_end_matches('/')
        );
        if let Err(why) = crate::tools::search_and_web::validate_public_url(&endpoint) {
            return Ok(ToolResult::error(format!(
                "Refused TRANSCRIBE_API_BASE ({}): {}",
                why, endpoint
            )));
        }
        let client = reqwest::Client::new();
        let resp = client
            .post(&endpoint)
            .bearer_auth(&key)
            .multipart(form)
            .send()
            .await?;
        if !resp.status().is_success() {
            let status = resp.status();
            let body = resp.text().await.unwrap_or_default();
            return Ok(ToolResult::error(format!(
                "transcribe API error {}: {}",
                status, body
            )));
        }
        let value: serde_json::Value = resp.json().await?;
        let text = value["text"].as_str().unwrap_or("").to_string();
        Ok(ToolResult::ok(vec![Block::Text(text)]))
    }
}

// Minimal, dependency-free base64 decoder for the standard alphabet
// (`A-Z a-z 0-9 + /`). Lenient by design: `=`, `\n` and `\r` are ignored
// wherever they appear, and any leftover bits that do not complete a byte
// are dropped.
mod base64_decode {
    /// Translate one standard-alphabet character into its 6-bit value.
    fn sextet(ch: u8) -> Option<u32> {
        let value = match ch {
            b'A'..=b'Z' => ch - b'A',
            b'a'..=b'z' => ch - b'a' + 26,
            b'0'..=b'9' => ch - b'0' + 52,
            b'+' => 62,
            b'/' => 63,
            _ => return None,
        };
        Some(u32::from(value))
    }

    /// Decode `s` into raw bytes.
    ///
    /// # Errors
    /// Returns `Err("invalid base64 char")` on the first character outside
    /// the alphabet that is not one of the ignored `=`, `\n`, `\r`.
    pub fn decode(s: &str) -> Result<Vec<u8>, &'static str> {
        let mut decoded = Vec::with_capacity(s.len() / 4 * 3);
        // `acc` collects bits MSB-first; `pending` counts how many of its low
        // bits have not yet been emitted. Bits shifted past the top of the
        // u32 are already-emitted data, so losing them is harmless.
        let mut acc: u32 = 0;
        let mut pending: u32 = 0;
        for ch in s.bytes().filter(|b| !matches!(*b, b'=' | b'\n' | b'\r')) {
            let six = sextet(ch).ok_or("invalid base64 char")?;
            acc = (acc << 6) | six;
            pending += 6;
            if pending >= 8 {
                pending -= 8;
                decoded.push((acc >> pending) as u8);
            }
        }
        Ok(decoded)
    }
}