zeptoclaw 0.6.2

Ultra-lightweight personal AI assistant
Documentation
//! Voice transcription tool using Groq Whisper API.

use async_trait::async_trait;
use serde_json::Value;
use std::path::Path;

use crate::error::{Result, ZeptoError};
use crate::security::validate_path_in_workspace;
use crate::tools::{Tool, ToolContext, ToolOutput};

/// Maximum file size accepted for transcription (25 MiB).
const MAX_FILE_BYTES: u64 = 25 * 1024 * 1024;

/// Supported audio extensions for the Groq Whisper API.
const SUPPORTED_EXTENSIONS: &[&str] = &["mp3", "mp4", "mpeg", "mpga", "m4a", "wav", "webm", "ogg"];

/// Map a file extension to the correct MIME type for the Groq multipart upload.
fn mime_for_extension(ext: &str) -> &'static str {
    match ext {
        "mp3" | "mpeg" | "mpga" => "audio/mpeg",
        "mp4" | "m4a" => "audio/mp4",
        "wav" => "audio/wav",
        "webm" => "audio/webm",
        "ogg" => "audio/ogg",
        _ => "application/octet-stream",
    }
}

pub struct TranscribeTool {
    api_key: String,
    model: String,
    client: reqwest::Client,
}

impl TranscribeTool {
    /// Create a new `TranscribeTool`.
    ///
    /// Returns `Err` if the underlying HTTP client cannot be built.
    pub fn new(api_key: impl Into<String>, model: impl Into<String>) -> Result<Self> {
        let client = reqwest::Client::builder()
            .timeout(std::time::Duration::from_secs(60))
            .build()
            .map_err(|e| ZeptoError::Tool(format!("Failed to build HTTP client: {}", e)))?;

        Ok(Self {
            api_key: api_key.into(),
            model: model.into(),
            client,
        })
    }

    async fn transcribe_file(&self, path: &str) -> Result<String> {
        // --- File size guard ---
        let metadata = tokio::fs::metadata(path)
            .await
            .map_err(|e| ZeptoError::Tool(format!("Failed to read audio file metadata: {}", e)))?;
        if metadata.len() > MAX_FILE_BYTES {
            return Err(ZeptoError::Tool(
                "File too large for transcription (max 25MB)".to_string(),
            ));
        }

        // --- MIME type from extension ---
        let ext = Path::new(path)
            .extension()
            .and_then(|e| e.to_str())
            .unwrap_or("")
            .to_lowercase();

        if !SUPPORTED_EXTENSIONS.contains(&ext.as_str()) {
            return Err(ZeptoError::Tool(format!(
                "Unsupported audio format '{}'. Supported: {}",
                ext,
                SUPPORTED_EXTENSIONS.join(", ")
            )));
        }

        let mime = mime_for_extension(&ext);

        let file_bytes = tokio::fs::read(path)
            .await
            .map_err(|e| ZeptoError::Tool(format!("Failed to read audio file: {}", e)))?;

        let filename = Path::new(path)
            .file_name()
            .and_then(|n| n.to_str())
            .unwrap_or("audio.ogg")
            .to_string();

        let part = reqwest::multipart::Part::bytes(file_bytes)
            .file_name(filename)
            .mime_str(mime)
            .map_err(|e| ZeptoError::Tool(e.to_string()))?;

        let form = reqwest::multipart::Form::new()
            .part("file", part)
            .text("model", self.model.clone());

        let resp = self
            .client
            .post("https://api.groq.com/openai/v1/audio/transcriptions")
            .bearer_auth(&self.api_key)
            .multipart(form)
            .send()
            .await
            .map_err(|e| ZeptoError::Tool(e.to_string()))?;

        if !resp.status().is_success() {
            let status = resp.status().as_u16();
            let body = resp.text().await.unwrap_or_default();
            return Err(ZeptoError::Tool(format!(
                "Groq transcription failed ({}): {}",
                status, body
            )));
        }

        let json: Value = resp
            .json()
            .await
            .map_err(|e| ZeptoError::Tool(e.to_string()))?;

        Ok(json["text"].as_str().unwrap_or("").to_string())
    }
}

#[async_trait]
impl Tool for TranscribeTool {
    fn name(&self) -> &str {
        "transcribe"
    }

    fn description(&self) -> &str {
        "Transcribe a voice or audio file to text using Groq Whisper. \
         Provide the local file path to the audio file. \
         Supported formats: mp3, mp4, mpeg, mpga, m4a, wav, webm, ogg."
    }

    fn compact_description(&self) -> &str {
        "Transcribe an audio file to text via Groq Whisper."
    }

    fn parameters(&self) -> Value {
        serde_json::json!({
            "type": "object",
            "properties": {
                "file_path": {
                    "type": "string",
                    "description": "Absolute or workspace-relative path to the audio file"
                }
            },
            "required": ["file_path"]
        })
    }

    async fn execute(&self, args: Value, ctx: &ToolContext) -> Result<ToolOutput> {
        let file_path = match args["file_path"].as_str() {
            Some(p) => p,
            None => return Ok(ToolOutput::error("file_path is required")),
        };

        let input_path = Path::new(file_path);

        // --- Path validation ---
        let resolved: String = if input_path.is_absolute() {
            // Absolute path: when a workspace is set, it must reside within it.
            if let Some(ws) = &ctx.workspace {
                match validate_path_in_workspace(file_path, ws) {
                    Ok(safe) => safe.as_path().to_string_lossy().to_string(),
                    Err(_) => {
                        return Ok(ToolOutput::error("Path is outside the workspace boundary"))
                    }
                }
            } else {
                // No workspace configured: allow absolute paths as-is.
                file_path.to_string()
            }
        } else {
            // Relative path: reject any `..` components (traversal guard).
            if input_path
                .components()
                .any(|c| matches!(c, std::path::Component::ParentDir))
            {
                return Ok(ToolOutput::error(
                    "Relative paths with '..' parent-directory components are not allowed",
                ));
            }

            if let Some(ws) = &ctx.workspace {
                // Workspace present: validate through the security module.
                match validate_path_in_workspace(file_path, ws) {
                    Ok(safe) => safe.as_path().to_string_lossy().to_string(),
                    Err(_) => {
                        return Ok(ToolOutput::error("Path is outside the workspace boundary"))
                    }
                }
            } else {
                // No workspace: bare relative path, use as-is.
                file_path.to_string()
            }
        };

        match self.transcribe_file(&resolved).await {
            Ok(text) if text.is_empty() => Ok(ToolOutput::llm_only(
                "Transcription returned empty (no speech detected)",
            )),
            Ok(text) => Ok(ToolOutput::user_visible(format!("Transcription: {}", text))),
            Err(e) => Ok(ToolOutput::error(format!("Transcription failed: {}", e))),
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    fn make_tool() -> TranscribeTool {
        TranscribeTool::new("key", "whisper-large-v3-turbo").unwrap()
    }

    #[test]
    fn test_transcribe_tool_name() {
        let tool = make_tool();
        assert_eq!(tool.name(), "transcribe");
    }

    #[test]
    fn test_transcribe_tool_description() {
        let tool = make_tool();
        assert!(tool.description().contains("Groq Whisper"));
        assert!(tool.description().contains("ogg"));
    }

    #[test]
    fn test_transcribe_tool_parameters() {
        let tool = make_tool();
        let params = tool.parameters();
        assert!(params["properties"]["file_path"].is_object());
        let required = params["required"].as_array().unwrap();
        assert!(required.iter().any(|v| v.as_str() == Some("file_path")));
    }

    #[tokio::test]
    async fn test_transcribe_missing_file_path() {
        let tool = make_tool();
        let ctx = ToolContext::new();
        let result = tool.execute(serde_json::json!({}), &ctx).await.unwrap();
        assert!(result.is_error);
        assert!(result.for_llm.contains("file_path is required"));
    }

    #[tokio::test]
    async fn test_transcribe_nonexistent_file() {
        let tool = make_tool();
        let ctx = ToolContext::new();
        let result = tool
            .execute(
                serde_json::json!({"file_path": "/nonexistent/audio.ogg"}),
                &ctx,
            )
            .await
            .unwrap();
        assert!(result.is_error);
        assert!(result.for_llm.contains("Transcription failed"));
    }

    #[test]
    fn test_relative_path_resolves_with_workspace() {
        // Verify path resolution logic inline
        let ws = "/workspace";
        let rel = "audio.ogg";
        let resolved = if std::path::Path::new(rel).is_absolute() {
            rel.to_string()
        } else {
            format!("{}/{}", ws, rel)
        };
        assert_eq!(resolved, "/workspace/audio.ogg");
    }

    #[test]
    fn test_absolute_path_not_resolved() {
        let abs = "/tmp/audio.ogg";
        let resolved = if std::path::Path::new(abs).is_absolute() {
            abs.to_string()
        } else {
            format!("/workspace/{}", abs)
        };
        assert_eq!(resolved, "/tmp/audio.ogg");
    }

    // ---- new hardening tests ----

    #[test]
    fn test_new_returns_result() {
        let result = TranscribeTool::new("key", "whisper-large-v3-turbo");
        assert!(result.is_ok());
    }

    #[test]
    fn test_mime_for_extension_mp3() {
        assert_eq!(mime_for_extension("mp3"), "audio/mpeg");
        assert_eq!(mime_for_extension("mpeg"), "audio/mpeg");
        assert_eq!(mime_for_extension("mpga"), "audio/mpeg");
    }

    #[test]
    fn test_mime_for_extension_mp4() {
        assert_eq!(mime_for_extension("mp4"), "audio/mp4");
        assert_eq!(mime_for_extension("m4a"), "audio/mp4");
    }

    #[test]
    fn test_mime_for_extension_wav() {
        assert_eq!(mime_for_extension("wav"), "audio/wav");
    }

    #[test]
    fn test_mime_for_extension_webm() {
        assert_eq!(mime_for_extension("webm"), "audio/webm");
    }

    #[test]
    fn test_mime_for_extension_ogg() {
        assert_eq!(mime_for_extension("ogg"), "audio/ogg");
    }

    #[test]
    fn test_mime_for_extension_unknown() {
        assert_eq!(mime_for_extension("flac"), "application/octet-stream");
        assert_eq!(mime_for_extension(""), "application/octet-stream");
    }

    #[tokio::test]
    async fn test_relative_path_with_parent_dir_rejected() {
        let tool = make_tool();
        let mut ctx = ToolContext::new();
        ctx.workspace = None;
        let result = tool
            .execute(serde_json::json!({"file_path": "../secret.ogg"}), &ctx)
            .await
            .unwrap();
        assert!(result.is_error);
        assert!(result.for_llm.contains("'..'"));
    }

    #[tokio::test]
    async fn test_relative_path_deep_traversal_rejected() {
        let tool = make_tool();
        let mut ctx = ToolContext::new();
        ctx.workspace = None;
        let result = tool
            .execute(serde_json::json!({"file_path": "a/../../etc/passwd"}), &ctx)
            .await
            .unwrap();
        assert!(result.is_error);
        assert!(result.for_llm.contains("'..'"));
    }

    #[tokio::test]
    async fn test_absolute_path_outside_workspace_rejected() {
        let tool = make_tool();
        let mut ctx = ToolContext::new();
        ctx.workspace = Some("/workspace".to_string());
        let result = tool
            .execute(serde_json::json!({"file_path": "/etc/passwd"}), &ctx)
            .await
            .unwrap();
        assert!(result.is_error);
        assert!(result.for_llm.contains("workspace boundary"));
    }

    #[tokio::test]
    async fn test_unsupported_extension_rejected() {
        let tool = make_tool();
        let ctx = ToolContext::new();
        // Use an absolute path so we don't trigger workspace check first.
        // The file won't exist so we'd normally hit a read error, but the
        // extension check happens *before* the read — we should see the
        // "Unsupported audio format" message instead.
        let result = tool
            .execute(serde_json::json!({"file_path": "/tmp/audio.flac"}), &ctx)
            .await
            .unwrap();
        assert!(result.is_error);
        // Either the metadata read fails (file doesn't exist) or the extension
        // check triggers — both surface as "Transcription failed".
        assert!(result.for_llm.contains("Transcription failed"));
    }
}