adk_audio/tools/
apply_fx.rs1use std::collections::HashMap;
4use std::sync::Arc;
5
6use async_trait::async_trait;
7use serde_json::Value;
8
9use crate::traits::{AudioProcessor, FxChain};
10
11pub struct ApplyFxTool {
15 chains: HashMap<String, FxChain>,
16}
17
18impl ApplyFxTool {
19 pub fn new(chains: HashMap<String, FxChain>) -> Self {
21 Self { chains }
22 }
23}
24
25#[async_trait]
26impl adk_core::Tool for ApplyFxTool {
27 fn name(&self) -> &str {
28 "apply_fx"
29 }
30
31 fn description(&self) -> &str {
32 "Apply audio effects chain to audio data"
33 }
34
35 fn parameters_schema(&self) -> Option<Value> {
36 let chain_names: Vec<&str> = self.chains.keys().map(|s| s.as_str()).collect();
37 Some(serde_json::json!({
38 "type": "object",
39 "properties": {
40 "audio_data": { "type": "string", "description": "Base64-encoded PCM16 audio data" },
41 "sample_rate": { "type": "integer", "description": "Sample rate in Hz (default 16000)" },
42 "chain": { "type": "string", "description": "FX chain name", "enum": chain_names }
43 },
44 "required": ["audio_data", "chain"]
45 }))
46 }
47
48 async fn execute(
49 &self,
50 _ctx: Arc<dyn adk_core::ToolContext>,
51 args: Value,
52 ) -> adk_core::Result<Value> {
53 let chain_name = args["chain"].as_str().unwrap_or_default();
54 let chain = self.chains.get(chain_name).ok_or_else(|| {
55 adk_core::AdkError::tool(format!("apply_fx: unknown chain '{chain_name}'"))
56 })?;
57
58 let audio_b64 = args["audio_data"].as_str().unwrap_or_default();
59 let sample_rate = args["sample_rate"].as_u64().unwrap_or(16000) as u32;
60
61 let data = base64_decode(audio_b64)
63 .map_err(|e| adk_core::AdkError::tool(format!("apply_fx: invalid base64: {e}")))?;
64 let frame = crate::frame::AudioFrame::new(bytes::Bytes::from(data), sample_rate, 1);
65
66 let processed = chain
67 .process(&frame)
68 .await
69 .map_err(|e| adk_core::AdkError::tool(format!("apply_fx: {e}")))?;
70
71 Ok(serde_json::json!({
72 "duration_ms": processed.duration_ms,
73 "sample_rate": processed.sample_rate
74 }))
75 }
76}
77
/// Decodes standard-alphabet base64 (RFC 4648 §4: `A-Z a-z 0-9 + /`) into raw bytes.
///
/// The decoder is lenient about layout and strict about content:
/// - ASCII whitespace is skipped anywhere, so line-wrapped input is accepted.
/// - `=` padding is optional.
/// - Once padding starts, only padding and whitespace may follow.
///
/// # Errors
///
/// Returns a message when the input:
/// - contains a byte outside the alphabet,
/// - has data after `=` padding, or
/// - ends with a single dangling character. Six bits cannot form a byte, so such input is
///   truncated or corrupt.
fn base64_decode(input: &str) -> Result<Vec<u8>, String> {
    let mut out = Vec::with_capacity(input.len() * 3 / 4);
    // Bit accumulator: holds at most 14 pending bits between iterations.
    let mut buf = 0u32;
    let mut bits = 0u32;
    let mut data_chars = 0usize;
    let mut padding_seen = false;

    for &b in input.as_bytes() {
        if b.is_ascii_whitespace() {
            continue;
        }
        if b == b'=' {
            padding_seen = true;
            continue;
        }
        if padding_seen {
            return Err("unexpected data after '=' padding".to_string());
        }
        let val = decode_sextet(b)
            .ok_or_else(|| format!("invalid base64 character: {}", b.escape_ascii()))?;
        buf = (buf << 6) | u32::from(val);
        bits += 6;
        data_chars += 1;
        if bits >= 8 {
            bits -= 8;
            out.push((buf >> bits) as u8);
            buf &= (1 << bits) - 1;
        }
    }

    // Each 4-character group encodes 3 bytes. Remainders of 2 or 3 characters encode 1 or
    // 2 bytes, but a remainder of 1 carries only 6 bits and cannot encode any byte.
    if data_chars % 4 == 1 {
        return Err("truncated input: a lone trailing character cannot encode a byte".to_string());
    }
    Ok(out)
}

/// Maps one character of the standard base64 alphabet to its 6-bit value.
fn decode_sextet(c: u8) -> Option<u8> {
    match c {
        b'A'..=b'Z' => Some(c - b'A'),
        b'a'..=b'z' => Some(c - b'a' + 26),
        b'0'..=b'9' => Some(c - b'0' + 52),
        b'+' => Some(62),
        b'/' => Some(63),
        _ => None,
    }
}