Skip to main content

adk_audio/tools/
apply_fx.rs

1//! ApplyFxTool — apply an FX chain to audio.
2
3use std::collections::HashMap;
4use std::sync::Arc;
5
6use async_trait::async_trait;
7use serde_json::Value;
8
9use crate::traits::{AudioProcessor, FxChain};
10
11/// Tool that applies a named FX chain to audio data.
12///
13/// Accepts JSON referencing audio data and an FX chain name.
14pub struct ApplyFxTool {
15    chains: HashMap<String, FxChain>,
16}
17
18impl ApplyFxTool {
19    /// Create a new `ApplyFxTool` with the given named FX chains.
20    pub fn new(chains: HashMap<String, FxChain>) -> Self {
21        Self { chains }
22    }
23}
24
25#[async_trait]
26impl adk_core::Tool for ApplyFxTool {
27    fn name(&self) -> &str {
28        "apply_fx"
29    }
30
31    fn description(&self) -> &str {
32        "Apply audio effects chain to audio data"
33    }
34
35    fn parameters_schema(&self) -> Option<Value> {
36        let chain_names: Vec<&str> = self.chains.keys().map(|s| s.as_str()).collect();
37        Some(serde_json::json!({
38            "type": "object",
39            "properties": {
40                "audio_data": { "type": "string", "description": "Base64-encoded PCM16 audio data" },
41                "sample_rate": { "type": "integer", "description": "Sample rate in Hz (default 16000)" },
42                "chain": { "type": "string", "description": "FX chain name", "enum": chain_names }
43            },
44            "required": ["audio_data", "chain"]
45        }))
46    }
47
48    async fn execute(
49        &self,
50        _ctx: Arc<dyn adk_core::ToolContext>,
51        args: Value,
52    ) -> adk_core::Result<Value> {
53        let chain_name = args["chain"].as_str().unwrap_or_default();
54        let chain = self.chains.get(chain_name).ok_or_else(|| {
55            adk_core::AdkError::tool(format!("apply_fx: unknown chain '{chain_name}'"))
56        })?;
57
58        let audio_b64 = args["audio_data"].as_str().unwrap_or_default();
59        let sample_rate = args["sample_rate"].as_u64().unwrap_or(16000) as u32;
60
61        // Decode base64
62        let data = base64_decode(audio_b64)
63            .map_err(|e| adk_core::AdkError::tool(format!("apply_fx: invalid base64: {e}")))?;
64        let frame = crate::frame::AudioFrame::new(bytes::Bytes::from(data), sample_rate, 1);
65
66        let processed = chain
67            .process(&frame)
68            .await
69            .map_err(|e| adk_core::AdkError::tool(format!("apply_fx: {e}")))?;
70
71        Ok(serde_json::json!({
72            "duration_ms": processed.duration_ms,
73            "sample_rate": processed.sample_rate
74        }))
75    }
76}
77
/// Decodes standard base64 (RFC 4648 alphabet, using `+` and `/`) into raw bytes.
///
/// ASCII whitespace, such as line breaks in wrapped input, is ignored. Trailing `=` padding
/// is optional.
///
/// # Errors
///
/// Returns a human-readable message in any of these cases:
/// - The input contains a byte outside the base64 alphabet.
/// - Non-padding data appears after `=` padding.
/// - The input ends with a single dangling character, a length no valid encoding can produce.
fn base64_decode(input: &str) -> Result<Vec<u8>, String> {
    /// Maps one character of the standard base64 alphabet to its 6-bit value.
    fn sextet(c: u8) -> Option<u32> {
        let v = match c {
            b'A'..=b'Z' => c - b'A',
            b'a'..=b'z' => c - b'a' + 26,
            b'0'..=b'9' => c - b'0' + 52,
            b'+' => 62,
            b'/' => 63,
            _ => return None,
        };
        Some(u32::from(v))
    }

    let input = input.as_bytes();
    // Exact upper bound on the decoded length for well-formed input.
    let mut out = Vec::with_capacity(input.len() * 3 / 4);
    // Bit accumulator. It never holds more than 12 pending bits before a byte is flushed.
    let mut buf = 0u32;
    let mut bits = 0u32;
    let mut seen_padding = false;

    for &b in input {
        if b.is_ascii_whitespace() {
            continue;
        }
        if b == b'=' {
            seen_padding = true;
            continue;
        }
        // Data after padding would be decoded misaligned, producing garbage bytes.
        if seen_padding {
            return Err("unexpected data after '=' padding".to_string());
        }
        let val = sextet(b).ok_or_else(|| {
            if b.is_ascii_graphic() {
                format!("invalid base64 character: {}", b as char)
            } else {
                format!("invalid base64 byte: 0x{b:02X}")
            }
        })?;
        buf = (buf << 6) | val;
        bits += 6;
        if bits >= 8 {
            bits -= 8;
            out.push((buf >> bits) as u8);
            buf &= (1 << bits) - 1;
        }
    }

    // Exactly 6 leftover bits means one trailing character, which cannot encode a whole byte.
    if bits == 6 {
        return Err("input ends with a dangling character (truncated base64)".to_string());
    }
    Ok(out)
}