adk_audio/tools/
transcribe.rs1use std::sync::Arc;
4
5use async_trait::async_trait;
6use serde_json::Value;
7
8use crate::traits::{SttOptions, SttProvider};
9
10pub struct TranscribeTool {
14 stt: Arc<dyn SttProvider>,
15}
16
17impl TranscribeTool {
18 pub fn new(stt: Arc<dyn SttProvider>) -> Self {
20 Self { stt }
21 }
22}
23
24#[async_trait]
25impl adk_core::Tool for TranscribeTool {
26 fn name(&self) -> &str {
27 "transcribe"
28 }
29
30 fn description(&self) -> &str {
31 "Transcribe audio to text"
32 }
33
34 fn parameters_schema(&self) -> Option<Value> {
35 Some(serde_json::json!({
36 "type": "object",
37 "properties": {
38 "audio_data": { "type": "string", "description": "Base64-encoded PCM16 audio data" },
39 "sample_rate": { "type": "integer", "description": "Sample rate in Hz (default 16000)" },
40 "language": { "type": "string", "description": "BCP-47 language hint (optional)" }
41 },
42 "required": ["audio_data"]
43 }))
44 }
45
46 async fn execute(
47 &self,
48 _ctx: Arc<dyn adk_core::ToolContext>,
49 args: Value,
50 ) -> adk_core::Result<Value> {
51 let audio_b64 = args["audio_data"].as_str().unwrap_or_default();
52 let sample_rate = args["sample_rate"].as_u64().unwrap_or(16000) as u32;
53 let language = args["language"].as_str().map(String::from);
54
55 use bytes::Bytes;
57 let data = base64_decode(audio_b64)
58 .map_err(|e| adk_core::AdkError::tool(format!("transcribe: invalid base64: {e}")))?;
59 let frame = crate::frame::AudioFrame::new(Bytes::from(data), sample_rate, 1);
60
61 let opts = SttOptions { language, ..Default::default() };
62 let transcript = self
63 .stt
64 .transcribe(&frame, &opts)
65 .await
66 .map_err(|e| adk_core::AdkError::tool(format!("transcribe: {e}")))?;
67
68 Ok(serde_json::json!({
69 "text": transcript.text,
70 "confidence": transcript.confidence
71 }))
72 }
73}
74
/// Decodes standard (RFC 4648 §4) base64 text into raw bytes.
///
/// Accepts the `A-Z a-z 0-9 + /` alphabet. ASCII whitespace anywhere in the
/// input (e.g. MIME-style `\r\n` line breaks) is ignored. Trailing `=` padding is
/// optional, but when present it must complete the final 4-character quantum.
///
/// # Errors
///
/// Returns a human-readable message in any of these cases:
/// - The input contains a character outside the alphabet.
/// - Data appears after `=` padding.
/// - The padding is malformed.
/// - The input ends with a single dangling character that cannot encode a whole byte.
fn base64_decode(input: &str) -> Result<Vec<u8>, String> {
    /// Maps one base64 alphabet character to its 6-bit value.
    fn sextet(ch: char) -> Option<u32> {
        match ch {
            'A'..='Z' => Some(ch as u32 - 'A' as u32),
            'a'..='z' => Some(ch as u32 - 'a' as u32 + 26),
            '0'..='9' => Some(ch as u32 - '0' as u32 + 52),
            '+' => Some(62),
            '/' => Some(63),
            _ => None,
        }
    }

    // Upper bound on the decoded size; avoids `len * 3` overflow on huge inputs.
    let mut out = Vec::with_capacity(input.len() / 4 * 3 + 3);
    let mut buf = 0u32; // pending bits, always < 2^14 after masking
    let mut bits = 0u32; // number of valid bits in `buf`
    let mut symbols = 0usize; // alphabet characters consumed
    let mut padding = 0usize; // `=` characters seen

    for (pos, ch) in input.char_indices() {
        if ch.is_ascii_whitespace() {
            continue;
        }
        if ch == '=' {
            padding += 1;
            continue;
        }
        // Padding may only terminate the input. Skipping it mid-stream would
        // misalign every following quantum and silently yield garbage.
        if padding > 0 {
            return Err(format!("unexpected base64 data after padding at byte {pos}"));
        }
        let val = sextet(ch)
            .ok_or_else(|| format!("invalid base64 character: {ch:?} at byte {pos}"))?;
        symbols += 1;
        buf = (buf << 6) | val;
        bits += 6;
        if bits >= 8 {
            bits -= 8;
            out.push((buf >> bits) as u8);
            buf &= (1 << bits) - 1;
        }
    }

    let tail = symbols % 4;
    // One leftover character carries only 6 bits: not enough for any byte.
    if tail == 1 {
        return Err("truncated base64 input: final quantum has a single character".to_string());
    }
    if padding > 0 && (padding > 2 || (tail + padding) % 4 != 0) {
        return Err(format!(
            "invalid base64 padding: {padding} '=' after {tail} trailing characters"
        ));
    }
    Ok(out)
}