use async_trait::async_trait;
use serde_json::Value;
use std::path::Path;
use crate::error::{Result, ZeptoError};
use crate::security::validate_path_in_workspace;
use crate::tools::{Tool, ToolContext, ToolOutput};
const MAX_FILE_BYTES: u64 = 25 * 1024 * 1024;
const SUPPORTED_EXTENSIONS: &[&str] = &["mp3", "mp4", "mpeg", "mpga", "m4a", "wav", "webm", "ogg"];
fn mime_for_extension(ext: &str) -> &'static str {
match ext {
"mp3" | "mpeg" | "mpga" => "audio/mpeg",
"mp4" | "m4a" => "audio/mp4",
"wav" => "audio/wav",
"webm" => "audio/webm",
"ogg" => "audio/ogg",
_ => "application/octet-stream",
}
}
pub struct TranscribeTool {
api_key: String,
model: String,
client: reqwest::Client,
}
impl TranscribeTool {
pub fn new(api_key: impl Into<String>, model: impl Into<String>) -> Result<Self> {
let client = reqwest::Client::builder()
.timeout(std::time::Duration::from_secs(60))
.build()
.map_err(|e| ZeptoError::Tool(format!("Failed to build HTTP client: {}", e)))?;
Ok(Self {
api_key: api_key.into(),
model: model.into(),
client,
})
}
async fn transcribe_file(&self, path: &str) -> Result<String> {
let metadata = tokio::fs::metadata(path)
.await
.map_err(|e| ZeptoError::Tool(format!("Failed to read audio file metadata: {}", e)))?;
if metadata.len() > MAX_FILE_BYTES {
return Err(ZeptoError::Tool(
"File too large for transcription (max 25MB)".to_string(),
));
}
let ext = Path::new(path)
.extension()
.and_then(|e| e.to_str())
.unwrap_or("")
.to_lowercase();
if !SUPPORTED_EXTENSIONS.contains(&ext.as_str()) {
return Err(ZeptoError::Tool(format!(
"Unsupported audio format '{}'. Supported: {}",
ext,
SUPPORTED_EXTENSIONS.join(", ")
)));
}
let mime = mime_for_extension(&ext);
let file_bytes = tokio::fs::read(path)
.await
.map_err(|e| ZeptoError::Tool(format!("Failed to read audio file: {}", e)))?;
let filename = Path::new(path)
.file_name()
.and_then(|n| n.to_str())
.unwrap_or("audio.ogg")
.to_string();
let part = reqwest::multipart::Part::bytes(file_bytes)
.file_name(filename)
.mime_str(mime)
.map_err(|e| ZeptoError::Tool(e.to_string()))?;
let form = reqwest::multipart::Form::new()
.part("file", part)
.text("model", self.model.clone());
let resp = self
.client
.post("https://api.groq.com/openai/v1/audio/transcriptions")
.bearer_auth(&self.api_key)
.multipart(form)
.send()
.await
.map_err(|e| ZeptoError::Tool(e.to_string()))?;
if !resp.status().is_success() {
let status = resp.status().as_u16();
let body = resp.text().await.unwrap_or_default();
return Err(ZeptoError::Tool(format!(
"Groq transcription failed ({}): {}",
status, body
)));
}
let json: Value = resp
.json()
.await
.map_err(|e| ZeptoError::Tool(e.to_string()))?;
Ok(json["text"].as_str().unwrap_or("").to_string())
}
}
#[async_trait]
impl Tool for TranscribeTool {
fn name(&self) -> &str {
"transcribe"
}
fn description(&self) -> &str {
"Transcribe a voice or audio file to text using Groq Whisper. \
Provide the local file path to the audio file. \
Supported formats: mp3, mp4, mpeg, mpga, m4a, wav, webm, ogg."
}
fn compact_description(&self) -> &str {
"Transcribe an audio file to text via Groq Whisper."
}
fn parameters(&self) -> Value {
serde_json::json!({
"type": "object",
"properties": {
"file_path": {
"type": "string",
"description": "Absolute or workspace-relative path to the audio file"
}
},
"required": ["file_path"]
})
}
async fn execute(&self, args: Value, ctx: &ToolContext) -> Result<ToolOutput> {
let file_path = match args["file_path"].as_str() {
Some(p) => p,
None => return Ok(ToolOutput::error("file_path is required")),
};
let input_path = Path::new(file_path);
let resolved: String = if input_path.is_absolute() {
if let Some(ws) = &ctx.workspace {
match validate_path_in_workspace(file_path, ws) {
Ok(safe) => safe.as_path().to_string_lossy().to_string(),
Err(_) => {
return Ok(ToolOutput::error("Path is outside the workspace boundary"))
}
}
} else {
file_path.to_string()
}
} else {
if input_path
.components()
.any(|c| matches!(c, std::path::Component::ParentDir))
{
return Ok(ToolOutput::error(
"Relative paths with '..' parent-directory components are not allowed",
));
}
if let Some(ws) = &ctx.workspace {
match validate_path_in_workspace(file_path, ws) {
Ok(safe) => safe.as_path().to_string_lossy().to_string(),
Err(_) => {
return Ok(ToolOutput::error("Path is outside the workspace boundary"))
}
}
} else {
file_path.to_string()
}
};
match self.transcribe_file(&resolved).await {
Ok(text) if text.is_empty() => Ok(ToolOutput::llm_only(
"Transcription returned empty (no speech detected)",
)),
Ok(text) => Ok(ToolOutput::user_visible(format!("Transcription: {}", text))),
Err(e) => Ok(ToolOutput::error(format!("Transcription failed: {}", e))),
}
}
}
#[cfg(test)]
mod tests {
use super::*;
fn make_tool() -> TranscribeTool {
TranscribeTool::new("key", "whisper-large-v3-turbo").unwrap()
}
#[test]
fn test_transcribe_tool_name() {
let tool = make_tool();
assert_eq!(tool.name(), "transcribe");
}
#[test]
fn test_transcribe_tool_description() {
let tool = make_tool();
assert!(tool.description().contains("Groq Whisper"));
assert!(tool.description().contains("ogg"));
}
#[test]
fn test_transcribe_tool_parameters() {
let tool = make_tool();
let params = tool.parameters();
assert!(params["properties"]["file_path"].is_object());
let required = params["required"].as_array().unwrap();
assert!(required.iter().any(|v| v.as_str() == Some("file_path")));
}
#[tokio::test]
async fn test_transcribe_missing_file_path() {
let tool = make_tool();
let ctx = ToolContext::new();
let result = tool.execute(serde_json::json!({}), &ctx).await.unwrap();
assert!(result.is_error);
assert!(result.for_llm.contains("file_path is required"));
}
#[tokio::test]
async fn test_transcribe_nonexistent_file() {
let tool = make_tool();
let ctx = ToolContext::new();
let result = tool
.execute(
serde_json::json!({"file_path": "/nonexistent/audio.ogg"}),
&ctx,
)
.await
.unwrap();
assert!(result.is_error);
assert!(result.for_llm.contains("Transcription failed"));
}
#[test]
fn test_relative_path_resolves_with_workspace() {
let ws = "/workspace";
let rel = "audio.ogg";
let resolved = if std::path::Path::new(rel).is_absolute() {
rel.to_string()
} else {
format!("{}/{}", ws, rel)
};
assert_eq!(resolved, "/workspace/audio.ogg");
}
#[test]
fn test_absolute_path_not_resolved() {
let abs = "/tmp/audio.ogg";
let resolved = if std::path::Path::new(abs).is_absolute() {
abs.to_string()
} else {
format!("/workspace/{}", abs)
};
assert_eq!(resolved, "/tmp/audio.ogg");
}
#[test]
fn test_new_returns_result() {
let result = TranscribeTool::new("key", "whisper-large-v3-turbo");
assert!(result.is_ok());
}
#[test]
fn test_mime_for_extension_mp3() {
assert_eq!(mime_for_extension("mp3"), "audio/mpeg");
assert_eq!(mime_for_extension("mpeg"), "audio/mpeg");
assert_eq!(mime_for_extension("mpga"), "audio/mpeg");
}
#[test]
fn test_mime_for_extension_mp4() {
assert_eq!(mime_for_extension("mp4"), "audio/mp4");
assert_eq!(mime_for_extension("m4a"), "audio/mp4");
}
#[test]
fn test_mime_for_extension_wav() {
assert_eq!(mime_for_extension("wav"), "audio/wav");
}
#[test]
fn test_mime_for_extension_webm() {
assert_eq!(mime_for_extension("webm"), "audio/webm");
}
#[test]
fn test_mime_for_extension_ogg() {
assert_eq!(mime_for_extension("ogg"), "audio/ogg");
}
#[test]
fn test_mime_for_extension_unknown() {
assert_eq!(mime_for_extension("flac"), "application/octet-stream");
assert_eq!(mime_for_extension(""), "application/octet-stream");
}
#[tokio::test]
async fn test_relative_path_with_parent_dir_rejected() {
let tool = make_tool();
let mut ctx = ToolContext::new();
ctx.workspace = None;
let result = tool
.execute(serde_json::json!({"file_path": "../secret.ogg"}), &ctx)
.await
.unwrap();
assert!(result.is_error);
assert!(result.for_llm.contains("'..'"));
}
#[tokio::test]
async fn test_relative_path_deep_traversal_rejected() {
let tool = make_tool();
let mut ctx = ToolContext::new();
ctx.workspace = None;
let result = tool
.execute(serde_json::json!({"file_path": "a/../../etc/passwd"}), &ctx)
.await
.unwrap();
assert!(result.is_error);
assert!(result.for_llm.contains("'..'"));
}
#[tokio::test]
async fn test_absolute_path_outside_workspace_rejected() {
let tool = make_tool();
let mut ctx = ToolContext::new();
ctx.workspace = Some("/workspace".to_string());
let result = tool
.execute(serde_json::json!({"file_path": "/etc/passwd"}), &ctx)
.await
.unwrap();
assert!(result.is_error);
assert!(result.for_llm.contains("workspace boundary"));
}
#[tokio::test]
async fn test_unsupported_extension_rejected() {
let tool = make_tool();
let ctx = ToolContext::new();
let result = tool
.execute(serde_json::json!({"file_path": "/tmp/audio.flac"}), &ctx)
.await
.unwrap();
assert!(result.is_error);
assert!(result.for_llm.contains("Transcription failed"));
}
}