Skip to main content

adk_audio/tools/
transcribe.rs

1//! TranscribeTool — transcribe audio via a configured SttProvider.
2
3use std::sync::Arc;
4
5use async_trait::async_trait;
6use serde_json::Value;
7
8use crate::traits::{SttOptions, SttProvider};
9
10/// Tool that transcribes audio to text.
11///
12/// Accepts JSON referencing an audio artifact and returns `{text, confidence}`.
13pub struct TranscribeTool {
14    stt: Arc<dyn SttProvider>,
15}
16
17impl TranscribeTool {
18    /// Create a new `TranscribeTool` with the given STT provider.
19    pub fn new(stt: Arc<dyn SttProvider>) -> Self {
20        Self { stt }
21    }
22}
23
24#[async_trait]
25impl adk_core::Tool for TranscribeTool {
26    fn name(&self) -> &str {
27        "transcribe"
28    }
29
30    fn description(&self) -> &str {
31        "Transcribe audio to text"
32    }
33
34    fn parameters_schema(&self) -> Option<Value> {
35        Some(serde_json::json!({
36            "type": "object",
37            "properties": {
38                "audio_data": { "type": "string", "description": "Base64-encoded PCM16 audio data" },
39                "sample_rate": { "type": "integer", "description": "Sample rate in Hz (default 16000)" },
40                "language": { "type": "string", "description": "BCP-47 language hint (optional)" }
41            },
42            "required": ["audio_data"]
43        }))
44    }
45
46    async fn execute(
47        &self,
48        _ctx: Arc<dyn adk_core::ToolContext>,
49        args: Value,
50    ) -> adk_core::Result<Value> {
51        let audio_b64 = args["audio_data"].as_str().unwrap_or_default();
52        let sample_rate = args["sample_rate"].as_u64().unwrap_or(16000) as u32;
53        let language = args["language"].as_str().map(String::from);
54
55        // Decode base64 audio
56        use bytes::Bytes;
57        let data = base64_decode(audio_b64)
58            .map_err(|e| adk_core::AdkError::tool(format!("transcribe: invalid base64: {e}")))?;
59        let frame = crate::frame::AudioFrame::new(Bytes::from(data), sample_rate, 1);
60
61        let opts = SttOptions { language, ..Default::default() };
62        let transcript = self
63            .stt
64            .transcribe(&frame, &opts)
65            .await
66            .map_err(|e| adk_core::AdkError::tool(format!("transcribe: {e}")))?;
67
68        Ok(serde_json::json!({
69            "text": transcript.text,
70            "confidence": transcript.confidence
71        }))
72    }
73}
74
/// Decode standard-alphabet base64 (`A-Z a-z 0-9 + /`) into raw bytes.
///
/// Hand-rolled to avoid adding a dependency. The decoder is deliberately
/// lenient in two ways: `\n` / `\r` are skipped wherever they appear, and
/// trailing `=` padding is optional.
///
/// # Errors
///
/// Returns a description of the problem when the input
/// - contains a byte outside the base64 alphabet,
/// - has further data symbols after `=` padding has started, or
/// - ends on a lone 6-bit symbol, which cannot encode a whole byte.
fn base64_decode(input: &str) -> Result<Vec<u8>, String> {
    let mut out = Vec::with_capacity(input.len() * 3 / 4);
    let mut buf = 0u32;
    let mut bits = 0u32;
    let mut padded = false;
    for &b in input.as_bytes() {
        // Map the symbol to its 6-bit value with range arithmetic instead of
        // scanning a 64-entry table for every input byte.
        let val = match b {
            b'\n' | b'\r' => continue,
            b'=' => {
                padded = true;
                continue;
            }
            b'A'..=b'Z' => b - b'A',
            b'a'..=b'z' => b - b'a' + 26,
            b'0'..=b'9' => b - b'0' + 52,
            b'+' => 62,
            b'/' => 63,
            _ => return Err(format!("invalid base64 character: {}", b as char)),
        };
        // Padding terminates the stream; anything after it would otherwise be
        // spliced onto the leftover bits and decode to garbage.
        if padded {
            return Err(String::from("unexpected base64 data after padding"));
        }
        buf = (buf << 6) | u32::from(val);
        bits += 6;
        if bits >= 8 {
            bits -= 8;
            // `buf` holds at most `bits + 8` significant bits here, so the
            // shifted value always fits in a byte.
            out.push((buf >> bits) as u8);
            buf &= (1 << bits) - 1;
        }
    }
    // Leftovers of 2 or 4 bits are the normal tail of a 2- or 3-symbol group;
    // 6 leftover bits mean a single stray symbol, i.e. truncated input.
    if bits == 6 {
        return Err(String::from("truncated base64 input"));
    }
    Ok(out)
}