Skip to main content

active_call/handler/
handler.rs

1use crate::{
2    app::AppState,
3    call::{
4        ActiveCall, ActiveCallType, Command,
5        active_call::{ActiveCallGuard, CallParams},
6    },
7    handler::playbook,
8    playbook::{Playbook, PlaybookRunner},
9};
10use crate::{event::SessionEvent, media::track::TrackConfig};
11use axum::{
12    Json, Router,
13    extract::{Path, Query, State, WebSocketUpgrade, ws::Message},
14    response::{IntoResponse, Response},
15    routing::get,
16};
17use bytes::Bytes;
18use chrono::Utc;
19use futures::{SinkExt, StreamExt};
20use rustrtc::IceServer;
21use serde_json::json;
22use std::{path::PathBuf, sync::Arc, time::Duration};
23use tokio::{join, select};
24use tokio_util::sync::CancellationToken;
25use tracing::{debug, info, trace, warn};
26use uuid::Uuid;
27
/// Keeps only the entries of `extras` whose key matches one of `allowed_headers`,
/// comparing ASCII case-insensitively. Values of retained entries are left untouched.
///
/// An empty `allowed_headers` removes every entry. The function is generic over the value
/// type, so it works for the `serde_json::Value` header maps used here as well as any other
/// `String`-keyed map.
fn filter_headers<V>(
    extras: &mut std::collections::HashMap<String, V>,
    allowed_headers: &[String],
) {
    extras.retain(|key, _| {
        allowed_headers
            .iter()
            .any(|allowed| allowed.eq_ignore_ascii_case(key))
    });
}

35pub fn call_router() -> Router<AppState> {
36    let r = Router::new()
37        .route("/call", get(ws_handler))
38        .route("/call/webrtc", get(webrtc_handler))
39        .route("/call/sip", get(sip_handler))
40        .route("/list", get(list_active_calls))
41        .route("/kill/{id}", get(kill_active_call));
42    r
43}
44
45pub fn iceservers_router() -> Router<AppState> {
46    let r = Router::new();
47    r.route("/iceservers", get(get_iceservers))
48}
49
50pub fn playbook_router() -> Router<AppState> {
51    Router::new()
52        .route("/api/playbooks", get(playbook::list_playbooks))
53        .route(
54            "/api/playbooks/{name}",
55            get(playbook::get_playbook).post(playbook::save_playbook),
56        )
57        .route(
58            "/api/playbook/run",
59            axum::routing::post(playbook::run_playbook),
60        )
61        .route("/api/records", get(playbook::list_records))
62}
63
64pub async fn ws_handler(
65    ws: WebSocketUpgrade,
66    State(state): State<AppState>,
67    Query(params): Query<CallParams>,
68) -> Response {
69    call_handler(ActiveCallType::WebSocket, ws, state, params).await
70}
71
72pub async fn sip_handler(
73    ws: WebSocketUpgrade,
74    State(state): State<AppState>,
75    Query(params): Query<CallParams>,
76) -> Response {
77    call_handler(ActiveCallType::Sip, ws, state, params).await
78}
79
80pub async fn webrtc_handler(
81    ws: WebSocketUpgrade,
82    State(state): State<AppState>,
83    Query(params): Query<CallParams>,
84) -> Response {
85    call_handler(ActiveCallType::Webrtc, ws, state, params).await
86}
87
88/// Core call handling logic that works with either WebSocket or mpsc channels
89pub async fn call_handler_core(
90    call_type: ActiveCallType,
91    session_id: String,
92    app_state: AppState,
93    cancel_token: CancellationToken,
94    audio_receiver: tokio::sync::mpsc::UnboundedReceiver<Bytes>,
95    server_side_track: Option<String>,
96    dump_events: bool,
97    ping_interval: u64,
98    mut command_receiver: tokio::sync::mpsc::UnboundedReceiver<Command>,
99    event_sender_to_client: tokio::sync::mpsc::UnboundedSender<crate::event::SessionEvent>,
100) {
101    let _cancel_guard = cancel_token.clone().drop_guard();
102    let track_config = TrackConfig::default();
103
104    // Check for pending params (extracted headers)
105    let extras = {
106        let mut pending = app_state.pending_params.lock().await;
107        pending.remove(&session_id)
108    };
109
110    let active_call = Arc::new(ActiveCall::new(
111        call_type.clone(),
112        cancel_token.clone(),
113        session_id.clone(),
114        app_state.invitation.clone(),
115        app_state.clone(),
116        track_config,
117        Some(audio_receiver),
118        dump_events,
119        server_side_track,
120        extras,
121        None,
122    ));
123
124    // Check for pending playbook
125    {
126        let mut pending = app_state.pending_playbooks.lock().await;
127        if let Some(name_or_content) = pending.remove(&session_id) {
128            let playbook_result = if name_or_content.trim().starts_with("---") {
129                Playbook::parse(&name_or_content)
130            } else {
131                // If path already contains config/playbook, use it as-is; otherwise prepend it
132                let path = if name_or_content.starts_with("config/playbook/") {
133                    PathBuf::from(&name_or_content)
134                } else {
135                    PathBuf::from("config/playbook").join(&name_or_content)
136                };
137                Playbook::load(path).await
138            };
139
140            match playbook_result {
141                Ok(mut playbook) => {
142                    // Filter extracted headers if configured (only for SIP calls)
143                    if call_type == ActiveCallType::Sip {
144                        if let Some(sip_config) = &playbook.config.sip {
145                            if let Some(allowed_headers) = &sip_config.extract_headers {
146                                let mut state = active_call.call_state.write().await;
147                                if let Some(extras) = &mut state.extras {
148                                    filter_headers(extras, allowed_headers);
149                                    // Store the list of SIP header keys for later template rendering
150                                    let header_keys: Vec<String> = extras
151                                        .keys()
152                                        .filter(|k| !k.starts_with('_'))
153                                        .cloned()
154                                        .collect();
155                                    extras.insert(
156                                        "_sip_header_keys".to_string(),
157                                        serde_json::to_value(&header_keys).unwrap_or_default(),
158                                    );
159                                    if let Ok(result) = playbook.render(extras) {
160                                        playbook = result;
161                                    }
162                                }
163                            }
164                        }
165                    }
166
167                    match PlaybookRunner::new(playbook, active_call.clone()) {
168                        Ok(runner) => {
169                            crate::spawn(async move {
170                                runner.run().await;
171                            });
172                            let display_name = if name_or_content.trim().starts_with("---") {
173                                "custom content"
174                            } else {
175                                &name_or_content
176                            };
177                            info!(session_id, "Playbook runner started for {}", display_name);
178                        }
179                        Err(e) => {
180                            let display_name = if name_or_content.trim().starts_with("---") {
181                                "custom content"
182                            } else {
183                                &name_or_content
184                            };
185                            warn!(
186                                session_id,
187                                "Failed to create runner {}: {}", display_name, e
188                            )
189                        }
190                    }
191                }
192                Err(e) => {
193                    let display_name = if name_or_content.trim().starts_with("---") {
194                        "custom content"
195                    } else {
196                        &name_or_content
197                    };
198                    warn!(
199                        session_id,
200                        "Failed to load playbook {}: {}", display_name, e
201                    );
202                    let event = SessionEvent::Error {
203                        timestamp: crate::media::get_timestamp(),
204                        track_id: session_id.clone(),
205                        sender: "playbook".to_string(),
206                        error: format!("{}", e),
207                        code: None,
208                    };
209                    event_sender_to_client.send(event).ok();
210                    return;
211                }
212            }
213        }
214    }
215
216    let recv_commands_loop = async {
217        while let Some(command) = command_receiver.recv().await {
218            if let Err(_) = active_call.enqueue_command(command).await {
219                break;
220            }
221        }
222    };
223
224    let mut event_receiver = active_call.event_sender.subscribe();
225    let send_events_loop = async {
226        loop {
227            match event_receiver.recv().await {
228                Ok(event) => {
229                    if let Err(_) = event_sender_to_client.send(event) {
230                        break;
231                    }
232                }
233                Err(tokio::sync::broadcast::error::RecvError::Lagged(_)) => continue,
234                Err(_) => break,
235            }
236        }
237    };
238
239    let send_ping_loop = async {
240        if ping_interval == 0 {
241            active_call.cancel_token.cancelled().await;
242            return;
243        }
244        let mut ticker = tokio::time::interval(Duration::from_secs(ping_interval));
245        loop {
246            ticker.tick().await;
247            let payload = Utc::now().to_rfc3339();
248            let event = SessionEvent::Ping {
249                timestamp: crate::media::get_timestamp(),
250                payload: Some(payload),
251            };
252            if let Err(_) = active_call.event_sender.send(event) {
253                break;
254            }
255        }
256    };
257
258    let guard = ActiveCallGuard::new(active_call.clone());
259    info!(
260        session_id,
261        active_calls = guard.active_calls,
262        ?call_type,
263        "new call started"
264    );
265    let receiver = active_call.new_receiver();
266
267    let (r, _) = join! {
268        active_call.serve(receiver),
269        async {
270            select!{
271                _ = send_ping_loop => {},
272                _ = cancel_token.cancelled() => {},
273                _ = send_events_loop => { },
274                _ = recv_commands_loop => {
275                    info!(session_id, "Command receiver closed");
276                },
277            }
278            cancel_token.cancel();
279        }
280    };
281    // drain events
282    while let Ok(event) = event_receiver.try_recv() {
283        if let Err(_) = event_sender_to_client.send(event) {
284            break;
285        }
286    }
287    match r {
288        Ok(_) => info!(session_id, "call ended successfully"),
289        Err(e) => warn!(session_id, "call ended with error: {}", e),
290    }
291
292    // Write back final extras to pending_params for post-call processing (e.g. BYE headers)
293    if call_type == ActiveCallType::Sip {
294        let state = active_call.call_state.read().await;
295        if let Some(extras) = &state.extras {
296            let mut pending = app_state.pending_params.lock().await;
297            pending.insert(session_id.clone(), extras.clone());
298        }
299    }
300
301    active_call.cleanup().await.ok();
302    debug!(session_id, "Call handler core completed");
303}
304
305pub async fn call_handler(
306    call_type: ActiveCallType,
307    ws: WebSocketUpgrade,
308    app_state: AppState,
309    params: CallParams,
310) -> Response {
311    let session_id = params
312        .id
313        .unwrap_or_else(|| format!("s.{}", Uuid::new_v4().to_string()));
314    let server_side_track = params.server_side_track.clone();
315    let dump_events = params.dump_events.unwrap_or(true);
316    let ping_interval = params.ping_interval.unwrap_or(20);
317
318    let resp = ws.on_upgrade(move |socket| async move {
319        let (mut ws_sender, mut ws_receiver) = socket.split();
320        let (audio_sender, audio_receiver) = tokio::sync::mpsc::unbounded_channel::<Bytes>();
321        let (command_sender, command_receiver) = tokio::sync::mpsc::unbounded_channel::<Command>();
322        let (event_sender_to_client, mut event_receiver_from_core) =
323            tokio::sync::mpsc::unbounded_channel::<crate::event::SessionEvent>();
324        let cancel_token = CancellationToken::new();
325
326        // Start core handler in background
327        let session_id_clone = session_id.clone();
328        let app_state_clone = app_state.clone();
329        let cancel_token_clone = cancel_token.clone();
330        crate::spawn(call_handler_core(
331            call_type,
332            session_id_clone,
333            app_state_clone,
334            cancel_token_clone,
335            audio_receiver,
336            server_side_track,
337            dump_events,
338            ping_interval.into(),
339            command_receiver,
340            event_sender_to_client,
341        ));
342
343        // Handle WebSocket I/O
344        let recv_from_ws_loop = async {
345            while let Some(Ok(message)) = ws_receiver.next().await {
346                match message {
347                    Message::Text(text) => {
348                        let command = match serde_json::from_str::<Command>(&text) {
349                            Ok(cmd) => cmd,
350                            Err(e) => {
351                                warn!(session_id, %text, "Failed to parse command {}", e);
352                                continue;
353                            }
354                        };
355                        if let Err(_) = command_sender.send(command) {
356                            break;
357                        }
358                    }
359                    Message::Binary(bin) => {
360                        audio_sender.send(bin.into()).ok();
361                    }
362                    Message::Close(_) => {
363                        info!(session_id, "WebSocket closed by client");
364                        break;
365                    }
366                    _ => {}
367                }
368            }
369        };
370
371        let send_to_ws_loop = async {
372            while let Some(event) = event_receiver_from_core.recv().await {
373                trace!(session_id, %event, "Sending WS message");
374                let message = match event.into_ws_message() {
375                    Ok(msg) => msg,
376                    Err(e) => {
377                        warn!(session_id, error=%e, "Failed to serialize event to WS message");
378                        continue;
379                    }
380                };
381                if let Err(_) = ws_sender.send(message).await {
382                    info!(session_id, "WebSocket send failed, closing");
383                    break;
384                }
385            }
386        };
387
388        select! {
389            _ = recv_from_ws_loop => {
390                info!(session_id, "WebSocket receive loop ended");
391            },
392            _ = send_to_ws_loop => {
393                info!(session_id, "WebSocket send loop ended");
394            },
395        }
396
397        cancel_token.cancel();
398        ws_sender.flush().await.ok();
399        ws_sender.close().await.ok();
400        debug!(session_id, "WebSocket connection closed");
401    });
402    resp
403}
404
405pub(crate) async fn get_iceservers(State(state): State<AppState>) -> Response {
406    if let Some(ice_servers) = state.config.ice_servers.as_ref() {
407        return Json(ice_servers).into_response();
408    }
409    Json(vec![IceServer {
410        urls: vec!["stun:stun.l.google.com:19302".to_string()],
411        ..Default::default()
412    }])
413    .into_response()
414}
415
416pub(crate) async fn list_active_calls(State(state): State<AppState>) -> Response {
417    let calls = state
418        .active_calls
419        .lock()
420        .unwrap()
421        .iter()
422        .map(|(_, c)| {
423            if let Ok(cs) = c.call_state.try_read() {
424                json!({
425                    "id": c.session_id,
426                    "callType": c.call_type,
427                    "cs.option": cs.option,
428                    "ringTime": cs.ring_time,
429                    "startTime": cs.answer_time,
430                })
431            } else {
432                json!({
433                    "id": c.session_id,
434                    "callType": c.call_type,
435                    "status": "locked",
436                })
437            }
438        })
439        .collect::<Vec<_>>();
440    Json(serde_json::json!({ "active_calls": calls })).into_response()
441}
442
443pub(crate) async fn kill_active_call(
444    Path(id): Path<String>,
445    State(state): State<AppState>,
446) -> Response {
447    let active_calls = state.active_calls.lock().unwrap();
448    if let Some(call) = active_calls.get(&id) {
449        call.cancel_token.cancel();
450        Json(serde_json::json!({ "status": "killed", "id": id })).into_response()
451    } else {
452        Json(serde_json::json!({ "status": "not_found", "id": id })).into_response()
453    }
454}
455
456trait IntoWsMessage {
457    fn into_ws_message(self) -> Result<Message, serde_json::Error>;
458}
459
460impl IntoWsMessage for crate::event::SessionEvent {
461    fn into_ws_message(self) -> Result<Message, serde_json::Error> {
462        match self {
463            SessionEvent::Binary { data, .. } => Ok(Message::Binary(data.into())),
464            SessionEvent::Ping { timestamp, payload } => {
465                let payload = payload.unwrap_or_else(|| timestamp.to_string());
466                Ok(Message::Ping(payload.into()))
467            }
468            event => serde_json::to_string(&event).map(|payload| Message::Text(payload.into())),
469        }
470    }
471}
472
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;
    use std::collections::HashMap;

    /// Header map shared by the `filter_headers` tests.
    fn sample_extras() -> HashMap<String, serde_json::Value> {
        let mut extras = HashMap::new();
        extras.insert("X-Tenant-ID".to_string(), json!("123"));
        extras.insert("X-User-ID".to_string(), json!("456"));
        extras.insert("Custom-Header".to_string(), json!("abc"));
        extras.insert("Irrelevant-Header".to_string(), json!("ignore"));
        extras
    }

    #[test]
    fn test_filter_headers() {
        let mut extras = sample_extras();

        // Test case-insensitive matching
        let allowed = vec!["x-tenant-id".to_string(), "Custom-Header".to_string()];

        filter_headers(&mut extras, &allowed);

        assert!(extras.contains_key("X-Tenant-ID"));
        assert!(extras.contains_key("Custom-Header"));
        assert!(!extras.contains_key("X-User-ID"));
        assert!(!extras.contains_key("Irrelevant-Header"));

        // ensure values are preserved
        assert_eq!(extras.get("X-Tenant-ID").unwrap(), &json!("123"));
        assert_eq!(extras.get("Custom-Header").unwrap(), &json!("abc"));
    }

    #[test]
    fn test_filter_headers_empty_allowlist_removes_everything() {
        let mut extras = sample_extras();
        let allowed: Vec<String> = Vec::new();

        filter_headers(&mut extras, &allowed);

        assert!(extras.is_empty());
    }
}