Skip to main content

active_call/handler/
handler.rs

1use crate::{
2    app::AppState,
3    call::{
4        ActiveCall, ActiveCallType, Command,
5        active_call::{ActiveCallGuard, CallParams},
6    },
7    handler::playbook,
8    playbook::{Playbook, PlaybookRunner},
9};
10use crate::{event::SessionEvent, media::track::TrackConfig};
11use axum::{
12    Json, Router,
13    extract::{Path, Query, State, WebSocketUpgrade, ws::Message},
14    response::{IntoResponse, Response},
15    routing::get,
16};
17use bytes::Bytes;
18use chrono::Utc;
19use futures::{SinkExt, StreamExt};
20use rustrtc::IceServer;
21use serde_json::json;
22use std::collections::HashMap;
23use std::{path::PathBuf, sync::Arc, time::Duration};
24use tokio::{join, select};
25use tokio_util::sync::CancellationToken;
26use tracing::{debug, info, trace, warn};
27use uuid::Uuid;
28
/// Keep only the entries of `extras` whose key matches one of
/// `allowed_headers`, compared ASCII case-insensitively.
///
/// Every other entry is removed — including `_`-prefixed internal keys,
/// unless they are explicitly listed in `allowed_headers`. An empty
/// allow-list therefore clears the map.
///
/// The value type is generic so the same allow-list filtering works for any
/// header map; callers passing `HashMap<String, serde_json::Value>` are
/// unaffected.
fn filter_headers<V>(extras: &mut HashMap<String, V>, allowed_headers: &[String]) {
    extras.retain(|key, _| {
        allowed_headers
            .iter()
            .any(|allowed| allowed.eq_ignore_ascii_case(key))
    });
}
35
36pub fn call_router() -> Router<AppState> {
37    let r = Router::new()
38        .route("/call", get(ws_handler))
39        .route("/call/webrtc", get(webrtc_handler))
40        .route("/call/sip", get(sip_handler))
41        .route("/list", get(list_active_calls))
42        .route("/kill/{id}", get(kill_active_call));
43    r
44}
45
46pub fn iceservers_router() -> Router<AppState> {
47    let r = Router::new();
48    r.route("/iceservers", get(get_iceservers))
49}
50
51pub fn playbook_router() -> Router<AppState> {
52    Router::new()
53        .route("/api/playbooks", get(playbook::list_playbooks))
54        .route(
55            "/api/playbooks/{name}",
56            get(playbook::get_playbook).post(playbook::save_playbook),
57        )
58        .route(
59            "/api/playbook/run",
60            axum::routing::post(playbook::run_playbook),
61        )
62        .route("/api/records", get(playbook::list_records))
63}
64
65pub async fn ws_handler(
66    ws: WebSocketUpgrade,
67    State(state): State<AppState>,
68    Query(params): Query<CallParams>,
69) -> Response {
70    call_handler(ActiveCallType::WebSocket, ws, state, params).await
71}
72
73pub async fn sip_handler(
74    ws: WebSocketUpgrade,
75    State(state): State<AppState>,
76    Query(params): Query<CallParams>,
77) -> Response {
78    call_handler(ActiveCallType::Sip, ws, state, params).await
79}
80
81pub async fn webrtc_handler(
82    ws: WebSocketUpgrade,
83    State(state): State<AppState>,
84    Query(params): Query<CallParams>,
85) -> Response {
86    call_handler(ActiveCallType::Webrtc, ws, state, params).await
87}
88
89/// Core call handling logic that works with either WebSocket or mpsc channels
90///
91/// `extras` and `playbook_name` are session-scoped parameters passed directly
92/// by the caller (SIP handler, CLI, etc.) instead of through global maps.
93/// Returns the final call extras (including `_hangup_headers` if set) so the
94/// caller can use them for SIP BYE or other post-call processing.
95pub async fn call_handler_core(
96    call_type: ActiveCallType,
97    session_id: String,
98    app_state: AppState,
99    cancel_token: CancellationToken,
100    audio_receiver: tokio::sync::mpsc::UnboundedReceiver<Bytes>,
101    server_side_track: Option<String>,
102    dump_events: bool,
103    ping_interval: u64,
104    mut command_receiver: tokio::sync::mpsc::UnboundedReceiver<Command>,
105    event_sender_to_client: tokio::sync::mpsc::UnboundedSender<crate::event::SessionEvent>,
106    extras: Option<HashMap<String, serde_json::Value>>,
107    playbook_name: Option<String>,
108) -> Option<HashMap<String, serde_json::Value>> {
109    let _cancel_guard = cancel_token.clone().drop_guard();
110    let track_config = TrackConfig::default();
111
112    let active_call = Arc::new(ActiveCall::new(
113        call_type.clone(),
114        cancel_token.clone(),
115        session_id.clone(),
116        app_state.invitation.clone(),
117        app_state.clone(),
118        track_config,
119        Some(audio_receiver),
120        dump_events,
121        server_side_track,
122        extras,
123        None,
124    ));
125
126    // Load playbook: prefer direct parameter, fall back to pending_playbooks
127    // (pending_playbooks is used by the run_playbook HTTP endpoint)
128    {
129        let name_or_content = playbook_name.or_else(|| {
130            app_state
131                .pending_playbooks
132                .try_lock()
133                .ok()
134                .and_then(|mut pending| pending.remove(&session_id).map(|(val, _)| val))
135        });
136        if let Some(name_or_content) = name_or_content {
137            let playbook_result = if name_or_content.trim().starts_with("---") {
138                Playbook::parse(&name_or_content)
139            } else {
140                // If path already contains config/playbook, use it as-is; otherwise prepend it
141                let path = if name_or_content.starts_with("config/playbook/") {
142                    PathBuf::from(&name_or_content)
143                } else {
144                    PathBuf::from("config/playbook").join(&name_or_content)
145                };
146                Playbook::load(path).await
147            };
148
149            match playbook_result {
150                Ok(mut playbook) => {
151                    // Filter extracted headers if configured (only for SIP calls)
152                    if call_type == ActiveCallType::Sip {
153                        if let Some(sip_config) = &playbook.config.sip {
154                            if let Some(allowed_headers) = &sip_config.extract_headers {
155                                let mut state = active_call.call_state.write().await;
156                                if let Some(extras) = &mut state.extras {
157                                    filter_headers(extras, allowed_headers);
158                                    // Store the list of SIP header keys for later template rendering
159                                    let header_keys: Vec<String> = extras
160                                        .keys()
161                                        .filter(|k| !k.starts_with('_'))
162                                        .cloned()
163                                        .collect();
164                                    extras.insert(
165                                        "_sip_header_keys".to_string(),
166                                        serde_json::to_value(&header_keys).unwrap_or_default(),
167                                    );
168                                    if let Ok(result) = playbook.render(extras) {
169                                        playbook = result;
170                                    }
171                                }
172                            }
173                        }
174                    }
175
176                    match PlaybookRunner::new(playbook, active_call.clone()) {
177                        Ok(runner) => {
178                            crate::spawn(async move {
179                                runner.run().await;
180                            });
181                            let display_name = if name_or_content.trim().starts_with("---") {
182                                "custom content"
183                            } else {
184                                &name_or_content
185                            };
186                            info!(session_id, "Playbook runner started for {}", display_name);
187                        }
188                        Err(e) => {
189                            let display_name = if name_or_content.trim().starts_with("---") {
190                                "custom content"
191                            } else {
192                                &name_or_content
193                            };
194                            warn!(
195                                session_id,
196                                "Failed to create runner {}: {}", display_name, e
197                            )
198                        }
199                    }
200                }
201                Err(e) => {
202                    let display_name = if name_or_content.trim().starts_with("---") {
203                        "custom content"
204                    } else {
205                        &name_or_content
206                    };
207                    warn!(
208                        session_id,
209                        "Failed to load playbook {}: {}", display_name, e
210                    );
211                    let event = SessionEvent::Error {
212                        timestamp: crate::media::get_timestamp(),
213                        track_id: session_id.clone(),
214                        sender: "playbook".to_string(),
215                        error: format!("{}", e),
216                        code: None,
217                    };
218                    event_sender_to_client.send(event).ok();
219                    return None;
220                }
221            }
222        }
223    }
224
225    let recv_commands_loop = async {
226        while let Some(command) = command_receiver.recv().await {
227            if let Err(_) = active_call.enqueue_command(command).await {
228                break;
229            }
230        }
231    };
232
233    let mut event_receiver = active_call.event_sender.subscribe();
234    let send_events_loop = async {
235        loop {
236            match event_receiver.recv().await {
237                Ok(event) => {
238                    if let Err(_) = event_sender_to_client.send(event) {
239                        break;
240                    }
241                }
242                Err(tokio::sync::broadcast::error::RecvError::Lagged(_)) => continue,
243                Err(_) => break,
244            }
245        }
246    };
247
248    let send_ping_loop = async {
249        if ping_interval == 0 {
250            active_call.cancel_token.cancelled().await;
251            return;
252        }
253        let mut ticker = tokio::time::interval(Duration::from_secs(ping_interval));
254        loop {
255            ticker.tick().await;
256            let payload = Utc::now().to_rfc3339();
257            let event = SessionEvent::Ping {
258                timestamp: crate::media::get_timestamp(),
259                payload: Some(payload),
260            };
261            if let Err(_) = active_call.event_sender.send(event) {
262                break;
263            }
264        }
265    };
266
267    let guard = ActiveCallGuard::new(active_call.clone());
268    info!(
269        session_id,
270        active_calls = guard.active_calls,
271        ?call_type,
272        "new call started"
273    );
274    let receiver = active_call.new_receiver();
275
276    let (r, _) = join! {
277        active_call.serve(receiver),
278        async {
279            select!{
280                _ = send_ping_loop => {},
281                _ = cancel_token.cancelled() => {},
282                _ = send_events_loop => { },
283                _ = recv_commands_loop => {
284                    info!(session_id, "Command receiver closed");
285                },
286            }
287            cancel_token.cancel();
288        }
289    };
290    // drain events
291    while let Ok(event) = event_receiver.try_recv() {
292        if let Err(_) = event_sender_to_client.send(event) {
293            break;
294        }
295    }
296    match r {
297        Ok(_) => info!(session_id, "call ended successfully"),
298        Err(e) => warn!(session_id, "call ended with error: {}", e),
299    }
300
301    // Capture final extras (including _hangup_headers) before cleanup
302    let final_extras = active_call.call_state.read().await.extras.clone();
303
304    active_call.cleanup().await.ok();
305    debug!(session_id, "Call handler core completed");
306
307    final_extras
308}
309
310pub async fn call_handler(
311    call_type: ActiveCallType,
312    ws: WebSocketUpgrade,
313    app_state: AppState,
314    params: CallParams,
315) -> Response {
316    let session_id = params
317        .id
318        .unwrap_or_else(|| format!("s.{}", Uuid::new_v4().to_string()));
319    let server_side_track = params.server_side_track.clone();
320    let dump_events = params.dump_events.unwrap_or(true);
321    let ping_interval = params.ping_interval.unwrap_or(20);
322
323    let resp = ws.on_upgrade(move |socket| async move {
324        let (mut ws_sender, mut ws_receiver) = socket.split();
325        let (audio_sender, audio_receiver) = tokio::sync::mpsc::unbounded_channel::<Bytes>();
326        let (command_sender, command_receiver) = tokio::sync::mpsc::unbounded_channel::<Command>();
327        let (event_sender_to_client, mut event_receiver_from_core) =
328            tokio::sync::mpsc::unbounded_channel::<crate::event::SessionEvent>();
329        let cancel_token = CancellationToken::new();
330
331        // Start core handler in background
332        let session_id_clone = session_id.clone();
333        let app_state_clone = app_state.clone();
334        let cancel_token_clone = cancel_token.clone();
335        crate::spawn(async move {
336            call_handler_core(
337                call_type,
338                session_id_clone,
339                app_state_clone,
340                cancel_token_clone,
341                audio_receiver,
342                server_side_track,
343                dump_events,
344                ping_interval.into(),
345                command_receiver,
346                event_sender_to_client,
347                None, // extras — not used for WebSocket calls
348                None, // playbook_name — falls back to pending_playbooks
349            )
350            .await;
351        });
352
353        // Handle WebSocket I/O
354        let recv_from_ws_loop = async {
355            while let Some(Ok(message)) = ws_receiver.next().await {
356                match message {
357                    Message::Text(text) => {
358                        let command = match serde_json::from_str::<Command>(&text) {
359                            Ok(cmd) => cmd,
360                            Err(e) => {
361                                warn!(session_id, %text, "Failed to parse command {}", e);
362                                continue;
363                            }
364                        };
365                        if let Err(_) = command_sender.send(command) {
366                            break;
367                        }
368                    }
369                    Message::Binary(bin) => {
370                        audio_sender.send(bin.into()).ok();
371                    }
372                    Message::Close(_) => {
373                        info!(session_id, "WebSocket closed by client");
374                        break;
375                    }
376                    _ => {}
377                }
378            }
379        };
380
381        let send_to_ws_loop = async {
382            while let Some(event) = event_receiver_from_core.recv().await {
383                trace!(session_id, %event, "Sending WS message");
384                let message = match event.into_ws_message() {
385                    Ok(msg) => msg,
386                    Err(e) => {
387                        warn!(session_id, error=%e, "Failed to serialize event to WS message");
388                        continue;
389                    }
390                };
391                if let Err(_) = ws_sender.send(message).await {
392                    info!(session_id, "WebSocket send failed, closing");
393                    break;
394                }
395            }
396        };
397
398        select! {
399            _ = recv_from_ws_loop => {
400                info!(session_id, "WebSocket receive loop ended");
401            },
402            _ = send_to_ws_loop => {
403                info!(session_id, "WebSocket send loop ended");
404            },
405        }
406
407        cancel_token.cancel();
408        ws_sender.flush().await.ok();
409        ws_sender.close().await.ok();
410        debug!(session_id, "WebSocket connection closed");
411    });
412    resp
413}
414
415pub(crate) async fn get_iceservers(State(state): State<AppState>) -> Response {
416    if let Some(ice_servers) = state.config.ice_servers.as_ref() {
417        return Json(ice_servers).into_response();
418    }
419    Json(vec![IceServer {
420        urls: vec!["stun:stun.l.google.com:19302".to_string()],
421        ..Default::default()
422    }])
423    .into_response()
424}
425
426pub(crate) async fn list_active_calls(State(state): State<AppState>) -> Response {
427    let calls = state
428        .active_calls
429        .lock()
430        .unwrap()
431        .iter()
432        .map(|(_, c)| {
433            if let Ok(cs) = c.call_state.try_read() {
434                json!({
435                    "id": c.session_id,
436                    "callType": c.call_type,
437                    "cs.option": cs.option,
438                    "ringTime": cs.ring_time,
439                    "startTime": cs.answer_time,
440                })
441            } else {
442                json!({
443                    "id": c.session_id,
444                    "callType": c.call_type,
445                    "status": "locked",
446                })
447            }
448        })
449        .collect::<Vec<_>>();
450    Json(serde_json::json!({ "active_calls": calls })).into_response()
451}
452
453pub(crate) async fn kill_active_call(
454    Path(id): Path<String>,
455    State(state): State<AppState>,
456) -> Response {
457    let active_calls = state.active_calls.lock().unwrap();
458    if let Some(call) = active_calls.get(&id) {
459        call.cancel_token.cancel();
460        Json(serde_json::json!({ "status": "killed", "id": id })).into_response()
461    } else {
462        Json(serde_json::json!({ "status": "not_found", "id": id })).into_response()
463    }
464}
465
466trait IntoWsMessage {
467    fn into_ws_message(self) -> Result<Message, serde_json::Error>;
468}
469
470impl IntoWsMessage for crate::event::SessionEvent {
471    fn into_ws_message(self) -> Result<Message, serde_json::Error> {
472        match self {
473            SessionEvent::Binary { data, .. } => Ok(Message::Binary(data.into())),
474            SessionEvent::Ping { timestamp, payload } => {
475                let payload = payload.unwrap_or_else(|| timestamp.to_string());
476                Ok(Message::Ping(payload.into()))
477            }
478            event => serde_json::to_string(&event).map(|payload| Message::Text(payload.into())),
479        }
480    }
481}
482
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;
    use std::collections::HashMap;

    #[test]
    fn test_filter_headers() {
        let mut extras = HashMap::new();
        extras.insert("X-Tenant-ID".to_string(), json!("123"));
        extras.insert("X-User-ID".to_string(), json!("456"));
        extras.insert("Custom-Header".to_string(), json!("abc"));
        extras.insert("Irrelevant-Header".to_string(), json!("ignore"));

        // Test case-insensitive matching
        let allowed = vec!["x-tenant-id".to_string(), "Custom-Header".to_string()];

        filter_headers(&mut extras, &allowed);

        assert!(extras.contains_key("X-Tenant-ID"));
        assert!(extras.contains_key("Custom-Header"));
        assert!(!extras.contains_key("X-User-ID"));
        assert!(!extras.contains_key("Irrelevant-Header"));

        // ensure values are preserved
        assert_eq!(extras.get("X-Tenant-ID").unwrap(), &json!("123"));
        assert_eq!(extras.get("Custom-Header").unwrap(), &json!("abc"));
    }

    /// An empty allow-list removes every entry, including `_`-prefixed keys.
    #[test]
    fn test_filter_headers_empty_allow_list_removes_everything() {
        let mut extras = HashMap::new();
        extras.insert("X-Tenant-ID".to_string(), json!("123"));
        extras.insert("_internal".to_string(), json!(true));

        filter_headers(&mut extras, &[]);

        assert!(extras.is_empty());
    }

    /// Ping events are sent as WebSocket ping frames, not as JSON text.
    #[test]
    fn test_ping_event_becomes_ws_ping_frame() {
        let with_payload = SessionEvent::Ping {
            timestamp: crate::media::get_timestamp(),
            payload: Some("hello".to_string()),
        };
        match with_payload.into_ws_message() {
            Ok(Message::Ping(data)) => assert_eq!(&data[..], b"hello"),
            other => panic!("expected a ping frame, got {:?}", other),
        }

        // Without a payload, the timestamp is used as the ping body.
        let timestamp = crate::media::get_timestamp();
        let expected = timestamp.to_string();
        let without_payload = SessionEvent::Ping {
            timestamp,
            payload: None,
        };
        match without_payload.into_ws_message() {
            Ok(Message::Ping(data)) => assert_eq!(&data[..], expected.as_bytes()),
            other => panic!("expected a ping frame, got {:?}", other),
        }
    }

    #[tokio::test]
    async fn test_call_handler_core_extras_are_session_scoped() {
        use crate::app::AppStateBuilder;
        use crate::call::{ActiveCallType, Command};
        use crate::config::Config;

        let mut config = Config::default();
        config.udp_port = 0;
        let app_state = AppStateBuilder::new()
            .with_config(config)
            .build()
            .await
            .expect("Failed to build app state");

        let session_id = "test-session-scoped".to_string();
        let cancel_token = CancellationToken::new();

        // Pass extras directly as a parameter (not via global map)
        let mut extras = HashMap::new();
        extras.insert("X-Custom".to_string(), json!("value"));

        let (_audio_sender, audio_receiver) = tokio::sync::mpsc::unbounded_channel::<Bytes>();
        let (command_sender, command_receiver) = tokio::sync::mpsc::unbounded_channel::<Command>();
        let (event_sender, _event_receiver) =
            tokio::sync::mpsc::unbounded_channel::<crate::event::SessionEvent>();

        // Send a Hangup command immediately to end the call
        command_sender
            .send(Command::Hangup {
                reason: None,
                initiator: None,
                headers: None,
            })
            .ok();
        drop(command_sender);

        // Run call_handler_core with extras passed directly
        let final_extras = call_handler_core(
            ActiveCallType::Sip,
            session_id.clone(),
            app_state.clone(),
            cancel_token,
            audio_receiver,
            None,
            false,
            0,
            command_receiver,
            event_sender,
            Some(extras), // extras passed directly
            None,         // no playbook
        )
        .await;

        // Verify that final extras are returned and contain our custom header
        assert!(final_extras.is_some(), "final extras should be returned");
        let extras = final_extras.unwrap();
        assert_eq!(
            extras.get("X-Custom"),
            Some(&json!("value")),
            "session-scoped extras should be preserved"
        );
    }
}