Skip to main content

active_call/handler/
handler.rs

1use crate::{
2    app::AppState,
3    call::{
4        ActiveCall, ActiveCallType, Command,
5        active_call::{ActiveCallGuard, CallParams},
6    },
7    handler::playbook,
8    playbook::{Playbook, PlaybookRunner},
9};
10use crate::{event::SessionEvent, media::track::TrackConfig};
11use axum::{
12    Json, Router,
13    extract::{Path, Query, State, WebSocketUpgrade, ws::Message},
14    response::{IntoResponse, Response},
15    routing::get,
16};
17use bytes::Bytes;
18use chrono::Utc;
19use futures::{SinkExt, StreamExt};
20use rustrtc::IceServer;
21use serde_json::json;
22use std::{path::PathBuf, sync::Arc, time::Duration};
23use tokio::{join, select};
24use tokio_util::sync::CancellationToken;
25use tracing::{debug, info, trace, warn};
26use uuid::Uuid;
27
28fn filter_headers(
29    extras: &mut std::collections::HashMap<String, serde_json::Value>,
30    allowed_headers: &[String],
31) {
32    extras.retain(|k, _| allowed_headers.iter().any(|h| h.eq_ignore_ascii_case(k)));
33}
34
35pub fn call_router() -> Router<AppState> {
36    let r = Router::new()
37        .route("/call", get(ws_handler))
38        .route("/call/webrtc", get(webrtc_handler))
39        .route("/call/sip", get(sip_handler))
40        .route("/list", get(list_active_calls))
41        .route("/kill/{id}", get(kill_active_call));
42    r
43}
44
45pub fn iceservers_router() -> Router<AppState> {
46    let r = Router::new();
47    r.route("/iceservers", get(get_iceservers))
48}
49
50pub fn playbook_router() -> Router<AppState> {
51    Router::new()
52        .route("/api/playbooks", get(playbook::list_playbooks))
53        .route(
54            "/api/playbooks/{name}",
55            get(playbook::get_playbook).post(playbook::save_playbook),
56        )
57        .route(
58            "/api/playbook/run",
59            axum::routing::post(playbook::run_playbook),
60        )
61        .route("/api/records", get(playbook::list_records))
62}
63
64pub async fn ws_handler(
65    ws: WebSocketUpgrade,
66    State(state): State<AppState>,
67    Query(params): Query<CallParams>,
68) -> Response {
69    call_handler(ActiveCallType::WebSocket, ws, state, params).await
70}
71
72pub async fn sip_handler(
73    ws: WebSocketUpgrade,
74    State(state): State<AppState>,
75    Query(params): Query<CallParams>,
76) -> Response {
77    call_handler(ActiveCallType::Sip, ws, state, params).await
78}
79
80pub async fn webrtc_handler(
81    ws: WebSocketUpgrade,
82    State(state): State<AppState>,
83    Query(params): Query<CallParams>,
84) -> Response {
85    call_handler(ActiveCallType::Webrtc, ws, state, params).await
86}
87
88/// Core call handling logic that works with either WebSocket or mpsc channels
89pub async fn call_handler_core(
90    call_type: ActiveCallType,
91    session_id: String,
92    app_state: AppState,
93    cancel_token: CancellationToken,
94    audio_receiver: tokio::sync::mpsc::UnboundedReceiver<Bytes>,
95    server_side_track: Option<String>,
96    dump_events: bool,
97    ping_interval: u64,
98    mut command_receiver: tokio::sync::mpsc::UnboundedReceiver<Command>,
99    event_sender_to_client: tokio::sync::mpsc::UnboundedSender<crate::event::SessionEvent>,
100) {
101    let _cancel_guard = cancel_token.clone().drop_guard();
102    let track_config = TrackConfig::default();
103
104    // Check for pending params (extracted headers)
105    let extras = {
106        let mut pending = app_state.pending_params.lock().await;
107        pending.remove(&session_id)
108    };
109
110    let active_call = Arc::new(ActiveCall::new(
111        call_type.clone(),
112        cancel_token.clone(),
113        session_id.clone(),
114        app_state.invitation.clone(),
115        app_state.clone(),
116        track_config,
117        Some(audio_receiver),
118        dump_events,
119        server_side_track,
120        extras,
121        None,
122    ));
123
124    // Check for pending playbook
125    {
126        let mut pending = app_state.pending_playbooks.lock().await;
127        if let Some(name_or_content) = pending.remove(&session_id) {
128            let playbook_result = if name_or_content.trim().starts_with("---") {
129                Playbook::parse(&name_or_content)
130            } else {
131                // If path already contains config/playbook, use it as-is; otherwise prepend it
132                let path = if name_or_content.starts_with("config/playbook/") {
133                    PathBuf::from(&name_or_content)
134                } else {
135                    PathBuf::from("config/playbook").join(&name_or_content)
136                };
137                Playbook::load(path).await
138            };
139
140            match playbook_result {
141                Ok(mut playbook) => {
142                    // Filter extracted headers if configured (only for SIP calls)
143                    if call_type == ActiveCallType::Sip {
144                        if let Some(sip_config) = &playbook.config.sip {
145                            if let Some(allowed_headers) = &sip_config.extract_headers {
146                                let mut state = active_call.call_state.write().await;
147                                if let Some(extras) = &mut state.extras {
148                                    filter_headers(extras, allowed_headers);
149                                    // Store the list of SIP header keys for later template rendering
150                                    let header_keys: Vec<String> = extras
151                                        .keys()
152                                        .filter(|k| !k.starts_with('_'))
153                                        .cloned()
154                                        .collect();
155                                    extras.insert(
156                                        "_sip_header_keys".to_string(),
157                                        serde_json::to_value(&header_keys).unwrap_or_default(),
158                                    );
159                                    if let Ok(result) = playbook.render(extras) {
160                                        playbook = result;
161                                    }
162                                }
163                            }
164                        }
165                    }
166
167                    match PlaybookRunner::new(playbook, active_call.clone()) {
168                        Ok(runner) => {
169                            crate::spawn(async move {
170                                runner.run().await;
171                            });
172                            let display_name = if name_or_content.trim().starts_with("---") {
173                                "custom content"
174                            } else {
175                                &name_or_content
176                            };
177                            info!(session_id, "Playbook runner started for {}", display_name);
178                        }
179                        Err(e) => {
180                            let display_name = if name_or_content.trim().starts_with("---") {
181                                "custom content"
182                            } else {
183                                &name_or_content
184                            };
185                            warn!(
186                                session_id,
187                                "Failed to create runner {}: {}", display_name, e
188                            )
189                        }
190                    }
191                }
192                Err(e) => {
193                    let display_name = if name_or_content.trim().starts_with("---") {
194                        "custom content"
195                    } else {
196                        &name_or_content
197                    };
198                    warn!(
199                        session_id,
200                        "Failed to load playbook {}: {}", display_name, e
201                    );
202                    let event = SessionEvent::Error {
203                        timestamp: crate::media::get_timestamp(),
204                        track_id: session_id.clone(),
205                        sender: "playbook".to_string(),
206                        error: format!("{}", e),
207                        code: None,
208                    };
209                    event_sender_to_client.send(event).ok();
210                    return;
211                }
212            }
213        }
214    }
215
216    let recv_commands_loop = async {
217        while let Some(command) = command_receiver.recv().await {
218            if let Err(_) = active_call.enqueue_command(command).await {
219                break;
220            }
221        }
222    };
223
224    let mut event_receiver = active_call.event_sender.subscribe();
225    let send_events_loop = async {
226        loop {
227            match event_receiver.recv().await {
228                Ok(event) => {
229                    if let Err(_) = event_sender_to_client.send(event) {
230                        break;
231                    }
232                }
233                Err(tokio::sync::broadcast::error::RecvError::Lagged(_)) => continue,
234                Err(_) => break,
235            }
236        }
237    };
238
239    let send_ping_loop = async {
240        if ping_interval == 0 {
241            active_call.cancel_token.cancelled().await;
242            return;
243        }
244        let mut ticker = tokio::time::interval(Duration::from_secs(ping_interval));
245        loop {
246            ticker.tick().await;
247            let payload = Utc::now().to_rfc3339();
248            let event = SessionEvent::Ping {
249                timestamp: crate::media::get_timestamp(),
250                payload: Some(payload),
251            };
252            if let Err(_) = active_call.event_sender.send(event) {
253                break;
254            }
255        }
256    };
257
258    let guard = ActiveCallGuard::new(active_call.clone());
259    info!(
260        session_id,
261        active_calls = guard.active_calls,
262        ?call_type,
263        "new call started"
264    );
265    let receiver = active_call.new_receiver();
266
267    let (r, _) = join! {
268        active_call.serve(receiver),
269        async {
270            select!{
271                _ = send_ping_loop => {},
272                _ = cancel_token.cancelled() => {},
273                _ = send_events_loop => { },
274                _ = recv_commands_loop => {
275                    info!(session_id, "Command receiver closed");
276                },
277            }
278            cancel_token.cancel();
279        }
280    };
281    match r {
282        Ok(_) => info!(session_id, "call ended successfully"),
283        Err(e) => warn!(session_id, "call ended with error: {}", e),
284    }
285
286    // Write back final extras to pending_params for post-call processing (e.g. BYE headers)
287    if call_type == ActiveCallType::Sip {
288        let state = active_call.call_state.read().await;
289        if let Some(extras) = &state.extras {
290            let mut pending = app_state.pending_params.lock().await;
291            pending.insert(session_id.clone(), extras.clone());
292        }
293    }
294
295    active_call.cleanup().await.ok();
296    debug!(session_id, "Call handler core completed");
297}
298
299pub async fn call_handler(
300    call_type: ActiveCallType,
301    ws: WebSocketUpgrade,
302    app_state: AppState,
303    params: CallParams,
304) -> Response {
305    let session_id = params
306        .id
307        .unwrap_or_else(|| format!("s.{}", Uuid::new_v4().to_string()));
308    let server_side_track = params.server_side_track.clone();
309    let dump_events = params.dump_events.unwrap_or(true);
310    let ping_interval = params.ping_interval.unwrap_or(20);
311
312    let resp = ws.on_upgrade(move |socket| async move {
313        let (mut ws_sender, mut ws_receiver) = socket.split();
314        let (audio_sender, audio_receiver) = tokio::sync::mpsc::unbounded_channel::<Bytes>();
315        let (command_sender, command_receiver) = tokio::sync::mpsc::unbounded_channel::<Command>();
316        let (event_sender_to_client, mut event_receiver_from_core) =
317            tokio::sync::mpsc::unbounded_channel::<crate::event::SessionEvent>();
318        let cancel_token = CancellationToken::new();
319
320        // Start core handler in background
321        let session_id_clone = session_id.clone();
322        let app_state_clone = app_state.clone();
323        let cancel_token_clone = cancel_token.clone();
324        crate::spawn(call_handler_core(
325            call_type,
326            session_id_clone,
327            app_state_clone,
328            cancel_token_clone,
329            audio_receiver,
330            server_side_track,
331            dump_events,
332            ping_interval.into(),
333            command_receiver,
334            event_sender_to_client,
335        ));
336
337        // Handle WebSocket I/O
338        let recv_from_ws_loop = async {
339            while let Some(Ok(message)) = ws_receiver.next().await {
340                match message {
341                    Message::Text(text) => {
342                        let command = match serde_json::from_str::<Command>(&text) {
343                            Ok(cmd) => cmd,
344                            Err(e) => {
345                                warn!(session_id, %text, "Failed to parse command {}", e);
346                                continue;
347                            }
348                        };
349                        if let Err(_) = command_sender.send(command) {
350                            break;
351                        }
352                    }
353                    Message::Binary(bin) => {
354                        audio_sender.send(bin.into()).ok();
355                    }
356                    Message::Close(_) => {
357                        info!(session_id, "WebSocket closed by client");
358                        break;
359                    }
360                    _ => {}
361                }
362            }
363        };
364
365        let send_to_ws_loop = async {
366            while let Some(event) = event_receiver_from_core.recv().await {
367                trace!(session_id, %event, "Sending WS message");
368                let message = match event.into_ws_message() {
369                    Ok(msg) => msg,
370                    Err(e) => {
371                        warn!(session_id, error=%e, "Failed to serialize event to WS message");
372                        continue;
373                    }
374                };
375                if let Err(_) = ws_sender.send(message).await {
376                    info!(session_id, "WebSocket send failed, closing");
377                    break;
378                }
379            }
380        };
381
382        select! {
383            _ = recv_from_ws_loop => {
384                info!(session_id, "WebSocket receive loop ended");
385            },
386            _ = send_to_ws_loop => {
387                info!(session_id, "WebSocket send loop ended");
388            },
389            _ = cancel_token.cancelled() => {
390                info!(session_id, "WebSocket cancelled");
391            },
392        }
393
394        cancel_token.cancel();
395        ws_sender.flush().await.ok();
396        ws_sender.close().await.ok();
397        debug!(session_id, "WebSocket connection closed");
398    });
399    resp
400}
401
402pub(crate) async fn get_iceservers(State(state): State<AppState>) -> Response {
403    if let Some(ice_servers) = state.config.ice_servers.as_ref() {
404        return Json(ice_servers).into_response();
405    }
406    Json(vec![IceServer {
407        urls: vec!["stun:stun.l.google.com:19302".to_string()],
408        ..Default::default()
409    }])
410    .into_response()
411}
412
413pub(crate) async fn list_active_calls(State(state): State<AppState>) -> Response {
414    let calls = state
415        .active_calls
416        .lock()
417        .unwrap()
418        .iter()
419        .map(|(_, c)| {
420            if let Ok(cs) = c.call_state.try_read() {
421                json!({
422                    "id": c.session_id,
423                    "callType": c.call_type,
424                    "cs.option": cs.option,
425                    "ringTime": cs.ring_time,
426                    "startTime": cs.answer_time,
427                })
428            } else {
429                json!({
430                    "id": c.session_id,
431                    "callType": c.call_type,
432                    "status": "locked",
433                })
434            }
435        })
436        .collect::<Vec<_>>();
437    Json(serde_json::json!({ "active_calls": calls })).into_response()
438}
439
440pub(crate) async fn kill_active_call(
441    Path(id): Path<String>,
442    State(state): State<AppState>,
443) -> Response {
444    let active_calls = state.active_calls.lock().unwrap();
445    if let Some(call) = active_calls.get(&id) {
446        call.cancel_token.cancel();
447        Json(serde_json::json!({ "status": "killed", "id": id })).into_response()
448    } else {
449        Json(serde_json::json!({ "status": "not_found", "id": id })).into_response()
450    }
451}
452
453trait IntoWsMessage {
454    fn into_ws_message(self) -> Result<Message, serde_json::Error>;
455}
456
457impl IntoWsMessage for crate::event::SessionEvent {
458    fn into_ws_message(self) -> Result<Message, serde_json::Error> {
459        match self {
460            SessionEvent::Binary { data, .. } => Ok(Message::Binary(data.into())),
461            SessionEvent::Ping { timestamp, payload } => {
462                let payload = payload.unwrap_or_else(|| timestamp.to_string());
463                Ok(Message::Ping(payload.into()))
464            }
465            event => serde_json::to_string(&event).map(|payload| Message::Text(payload.into())),
466        }
467    }
468}
469
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;
    use std::collections::HashMap;

    #[test]
    fn test_filter_headers() {
        let mut extras: HashMap<String, serde_json::Value> = HashMap::from([
            ("X-Tenant-ID".to_string(), json!("123")),
            ("X-User-ID".to_string(), json!("456")),
            ("Custom-Header".to_string(), json!("abc")),
            ("Irrelevant-Header".to_string(), json!("ignore")),
        ]);

        // Allow-list entries match header names regardless of ASCII case.
        let allowed = ["x-tenant-id".to_string(), "Custom-Header".to_string()];
        filter_headers(&mut extras, &allowed);

        for kept in ["X-Tenant-ID", "Custom-Header"] {
            assert!(extras.contains_key(kept), "{kept} should be kept");
        }
        for dropped in ["X-User-ID", "Irrelevant-Header"] {
            assert!(!extras.contains_key(dropped), "{dropped} should be dropped");
        }

        // Retained entries keep their original values.
        assert_eq!(extras["X-Tenant-ID"], json!("123"));
        assert_eq!(extras["Custom-Header"], json!("abc"));
    }
}