1use std::collections::HashMap;
2use std::sync::atomic::{AtomicBool, Ordering};
3use std::sync::Arc;
4
5use axum::extract::ws::Message;
6use base64::Engine;
7use tokio::sync::{mpsc, Mutex};
8
/// Identifies which peer a call is connected through.
///
/// The variant decides the JSON frame layout that
/// `CallRegistry::send_audio` produces for the call.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum Transport {
    /// Frames are built as `"event"`/`"streamSid"` messages.
    Twilio,
    /// Frames are built as `"type"`-tagged messages.
    Discord,
}
17
18pub struct ActiveCall {
20 pub stream_sid: String,
21 pub response_tx: mpsc::Sender<Message>,
22 pub speaking: Arc<AtomicBool>,
23}
24
25#[derive(Clone)]
27pub struct CallEntry {
28 pub stream_sid: String,
29 pub transport: Transport,
30 response_tx: mpsc::Sender<Message>,
31 speaking: Arc<AtomicBool>,
32}
33
34impl CallEntry {
35 pub fn set_speaking(&self, value: bool) {
36 self.speaking.store(value, Ordering::Relaxed);
37 }
38}
39
40#[derive(Clone)]
45pub struct CallRegistry {
46 inner: Arc<Mutex<HashMap<String, CallEntry>>>,
47}
48
49impl Default for CallRegistry {
50 fn default() -> Self {
51 Self::new()
52 }
53}
54
55impl CallRegistry {
56 pub fn new() -> Self {
57 Self {
58 inner: Arc::new(Mutex::new(HashMap::new())),
59 }
60 }
61
62 pub async fn register(
64 &self,
65 call_sid: String,
66 stream_sid: String,
67 transport: Transport,
68 response_tx: mpsc::Sender<Message>,
69 speaking: Arc<AtomicBool>,
70 ) {
71 tracing::info!(
72 call_sid = %call_sid,
73 stream_sid = %stream_sid,
74 transport = ?transport,
75 "Call registered"
76 );
77 self.inner.lock().await.insert(
78 call_sid,
79 CallEntry {
80 stream_sid,
81 transport,
82 response_tx,
83 speaking,
84 },
85 );
86 }
87
88 pub async fn deregister(&self, call_sid: &str) {
90 if self.inner.lock().await.remove(call_sid).is_some() {
91 tracing::info!(call_sid = %call_sid, "Call deregistered");
92 }
93 }
94
95 pub async fn get(&self, call_sid: &str) -> Option<CallEntry> {
97 self.inner.lock().await.get(call_sid).cloned()
98 }
99
100 pub async fn send_audio(
106 entry: &CallEntry,
107 mulaw_bytes: &[u8],
108 ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
109 match entry.transport {
110 Transport::Twilio => {
111 for chunk in mulaw_bytes.chunks(160) {
112 let b64 = base64::engine::general_purpose::STANDARD.encode(chunk);
113 let msg = serde_json::json!({
114 "event": "media",
115 "streamSid": entry.stream_sid,
116 "media": { "payload": b64 }
117 });
118 entry
119 .response_tx
120 .send(Message::Text(msg.to_string().into()))
121 .await?;
122 }
123
124 let mark = serde_json::json!({
126 "event": "mark",
127 "streamSid": entry.stream_sid,
128 "mark": { "name": "inject_end" }
129 });
130 entry
131 .response_tx
132 .send(Message::Text(mark.to_string().into()))
133 .await?;
134 }
135 Transport::Discord => {
136 for chunk in mulaw_bytes.chunks(160) {
138 let b64 = base64::engine::general_purpose::STANDARD.encode(chunk);
139 let msg = serde_json::json!({
140 "type": "audio",
141 "audio": b64
142 });
143 entry
144 .response_tx
145 .send(Message::Text(msg.to_string().into()))
146 .await?;
147 }
148
149 let mark = serde_json::json!({ "type": "mark" });
151 entry
152 .response_tx
153 .send(Message::Text(mark.to_string().into()))
154 .await?;
155 }
156 }
157
158 Ok(())
159 }
160}