Skip to main content

msg_gateway/
generic.rs

1use axum::{
2    Json,
3    extract::{
4        Path, State,
5        ws::{Message, WebSocket, WebSocketUpgrade},
6    },
7    http::header,
8    response::IntoResponse,
9};
10use futures_util::{SinkExt, StreamExt};
11use serde::{Deserialize, Serialize};
12use std::collections::HashMap;
13use std::sync::Arc;
14use tokio::sync::{RwLock, broadcast};
15
16use crate::error::AppError;
17use crate::guardrail::GuardrailVerdict;
18use crate::message::{InboundMessage, MessageSource, UserInfo, WsOutboundMessage};
19use crate::server::AppState;
20
21/// Registry of active WebSocket connections
22/// Key: (credential_id, chat_id) → broadcast sender for that chat
23pub type WsRegistry = Arc<RwLock<HashMap<(String, String), broadcast::Sender<WsOutboundMessage>>>>;
24
25/// Create a new WebSocket registry
26pub fn new_ws_registry() -> WsRegistry {
27    Arc::new(RwLock::new(HashMap::new()))
28}
29
30/// Request body for generic chat inbound
31#[derive(Debug, Deserialize)]
32pub struct ChatRequest {
33    pub chat_id: String,
34    pub text: String,
35    pub from: ChatUser,
36    #[serde(default)]
37    pub files: Vec<InboundFileRef>,
38}
39
40#[derive(Debug, Deserialize)]
41pub struct ChatUser {
42    pub id: String,
43    #[serde(default)]
44    pub display_name: Option<String>,
45}
46
47#[derive(Debug, Deserialize)]
48pub struct InboundFileRef {
49    pub url: String,
50    pub filename: String,
51    pub mime_type: String,
52    #[serde(default)]
53    pub auth_header: Option<String>,
54}
55
56/// Response for chat inbound
57#[derive(Debug, Serialize)]
58pub struct ChatResponse {
59    pub message_id: String,
60    pub timestamp: chrono::DateTime<chrono::Utc>,
61}
62
63/// POST /api/v1/chat/{credential_id}
64/// Generic adapter inbound - fire and forget
65pub async fn chat_inbound(
66    State(state): State<Arc<AppState>>,
67    Path(credential_id): Path<String>,
68    headers: axum::http::HeaderMap,
69    Json(payload): Json<ChatRequest>,
70) -> Result<impl IntoResponse, AppError> {
71    let config = state.config.read().await;
72
73    // Find credential
74    let credential = config
75        .credentials
76        .get(&credential_id)
77        .ok_or_else(|| AppError::CredentialNotFound(credential_id.clone()))?;
78
79    // Verify it's a generic adapter credential
80    if credential.adapter != "generic" {
81        return Err(AppError::Internal(format!(
82            "Credential {} is not a generic adapter credential",
83            credential_id
84        )));
85    }
86
87    // Verify token
88    let expected_token = &credential.token;
89    let auth_header = headers
90        .get(header::AUTHORIZATION)
91        .and_then(|v| v.to_str().ok());
92
93    match auth_header {
94        Some(auth) if auth.starts_with("Bearer ") => {
95            let token = &auth[7..];
96            if token != expected_token {
97                return Err(AppError::Unauthorized);
98            }
99        }
100        _ => return Err(AppError::Unauthorized),
101    }
102
103    // Check if credential is active
104    if !credential.active {
105        return Err(AppError::CredentialInactive(credential_id.clone()));
106    }
107
108    let route = credential.route.clone();
109
110    let backend_name = crate::backend::resolve_backend_name(credential, &config.gateway)
111        .ok_or_else(|| {
112            AppError::Internal("No backend configured for this credential".to_string())
113        })?;
114    let backend_cfg = config.backends.get(&backend_name).ok_or_else(|| {
115        AppError::Internal(format!("Backend '{}' not found in config", backend_name))
116    })?;
117    let gateway_ctx = crate::backend::GatewayContext {
118        gateway_url: format!("http://{}", config.gateway.listen),
119        send_token: config.auth.send_token.clone(),
120    };
121    let adapter = match crate::backend::create_adapter(
122        backend_cfg,
123        Some(&gateway_ctx),
124        credential.config.as_ref().or(backend_cfg.config.as_ref()),
125    ) {
126        Ok(a) => a,
127        Err(e) => {
128            return Err(AppError::Internal(format!(
129                "Failed to create backend adapter: {}",
130                e
131            )));
132        }
133    };
134    drop(config);
135
136    // Generate message ID
137    let message_id = format!("generic_{}", uuid::Uuid::new_v4());
138    let timestamp = chrono::Utc::now();
139
140    // Download and cache file attachments
141    let mut attachments = vec![];
142
143    // Warn once if files present but cache not configured
144    if state.file_cache.is_none() && !payload.files.is_empty() {
145        tracing::warn!("Files received but file cache not configured, skipping attachments");
146    }
147
148    for file_ref in &payload.files {
149        let Some(ref file_cache) = state.file_cache else {
150            continue;
151        };
152
153        match file_cache
154            .download_and_cache(
155                &file_ref.url,
156                file_ref.auth_header.as_deref(),
157                &file_ref.filename,
158                &file_ref.mime_type,
159            )
160            .await
161        {
162            Ok(cached) => {
163                attachments.push(crate::message::Attachment {
164                    filename: cached.filename.clone(),
165                    mime_type: cached.mime_type.clone(),
166                    size_bytes: cached.size_bytes,
167                    download_url: file_cache.get_download_url(&cached.file_id),
168                });
169                tracing::info!(
170                    file_id = %cached.file_id,
171                    filename = %cached.filename,
172                    "Generic inbound file cached"
173                );
174            }
175            Err(e) => {
176                tracing::warn!(
177                    url = %file_ref.url,
178                    error = %e,
179                    "Failed to cache generic inbound file attachment"
180                );
181            }
182        }
183    }
184
185    // Build normalized inbound message
186    let inbound = InboundMessage {
187        route,
188        credential_id: credential_id.clone(),
189        source: MessageSource {
190            protocol: "generic".to_string(),
191            chat_id: payload.chat_id.clone(),
192            message_id: message_id.clone(),
193            reply_to_message_id: None,
194            from: UserInfo {
195                id: payload.from.id,
196                username: None,
197                display_name: payload.from.display_name,
198            },
199        },
200        text: payload.text,
201        attachments,
202        timestamp,
203        extra_data: None,
204    };
205
206    let verdict = {
207        let engine = state.guardrail_engine.read().await;
208        engine.evaluate_inbound(&inbound)
209    };
210    match verdict {
211        GuardrailVerdict::Block { reject_message, .. } => {
212            return Err(AppError::Forbidden(reject_message));
213        }
214        GuardrailVerdict::Allow => {}
215    }
216
217    // Check if target server is down - buffer message instead of forwarding
218    let health_state = state.health_monitor.get_state().await;
219    if health_state == crate::health::HealthState::Down {
220        state.health_monitor.buffer_message(inbound).await;
221        tracing::info!(
222            credential_id = %credential_id,
223            message_id = %message_id,
224            "Message buffered (target server down)"
225        );
226    } else {
227        // Clone for the spawned task
228        let message_id_for_task = message_id.clone();
229        let credential_id_for_task = credential_id.clone();
230
231        // Forward to backend (fire and forget - spawn task)
232        tokio::spawn(async move {
233            match adapter.send_message(&inbound).await {
234                Ok(()) => {
235                    tracing::debug!(
236                        credential_id = %credential_id_for_task,
237                        message_id = %message_id_for_task,
238                        "Message forwarded to backend"
239                    );
240                }
241                Err(e) => {
242                    tracing::error!(
243                        credential_id = %credential_id_for_task,
244                        error = %e,
245                        "Failed to forward message to backend"
246                    );
247                }
248            }
249        });
250    }
251
252    // Return immediately (fire and forget)
253    Ok((
254        axum::http::StatusCode::ACCEPTED,
255        Json(ChatResponse {
256            message_id,
257            timestamp,
258        }),
259    ))
260}
261
262/// GET /ws/chat/{credential_id}/{chat_id}
263/// WebSocket upgrade for outbound messages
264pub async fn ws_handler(
265    State(state): State<Arc<AppState>>,
266    Path((credential_id, chat_id)): Path<(String, String)>,
267    headers: axum::http::HeaderMap,
268    ws: WebSocketUpgrade,
269) -> Result<impl IntoResponse, AppError> {
270    let config = state.config.read().await;
271
272    // Find credential
273    let credential = config
274        .credentials
275        .get(&credential_id)
276        .ok_or_else(|| AppError::CredentialNotFound(credential_id.clone()))?;
277
278    // Verify it's a generic adapter credential
279    if credential.adapter != "generic" {
280        return Err(AppError::Internal(format!(
281            "Credential {} is not a generic adapter credential",
282            credential_id
283        )));
284    }
285
286    // Verify token
287    let expected_token = &credential.token;
288    let auth_header = headers
289        .get(header::AUTHORIZATION)
290        .and_then(|v| v.to_str().ok());
291
292    match auth_header {
293        Some(auth) if auth.starts_with("Bearer ") => {
294            let token = &auth[7..];
295            if token != expected_token {
296                return Err(AppError::Unauthorized);
297            }
298        }
299        _ => return Err(AppError::Unauthorized),
300    }
301
302    if !credential.active {
303        return Err(AppError::CredentialInactive(credential_id.clone()));
304    }
305
306    drop(config);
307
308    let ws_registry = state.ws_registry.clone();
309    let cred_id = credential_id.clone();
310    let c_id = chat_id.clone();
311
312    Ok(ws.on_upgrade(move |socket| handle_ws(socket, ws_registry, cred_id, c_id)))
313}
314
315async fn handle_ws(
316    socket: WebSocket,
317    registry: WsRegistry,
318    credential_id: String,
319    chat_id: String,
320) {
321    let (sender, mut receiver) = socket.split();
322    let sender = Arc::new(tokio::sync::Mutex::new(sender));
323
324    // Create or get broadcast channel for this chat
325    let mut rx = {
326        let mut reg = registry.write().await;
327        let key = (credential_id.clone(), chat_id.clone());
328
329        let tx = reg.entry(key).or_insert_with(|| {
330            let (tx, _) = broadcast::channel(100);
331            tx
332        });
333
334        tx.subscribe()
335    };
336
337    tracing::info!(
338        credential_id = %credential_id,
339        chat_id = %chat_id,
340        "WebSocket connected"
341    );
342
343    // Spawn task to receive messages from broadcast and send to WebSocket
344    let sender_clone = sender.clone();
345    let send_task = tokio::spawn(async move {
346        while let Ok(msg) = rx.recv().await {
347            let json = serde_json::to_string(&msg).unwrap();
348            let mut s = sender_clone.lock().await;
349            if s.send(Message::Text(json.into())).await.is_err() {
350                break;
351            }
352        }
353    });
354
355    // Handle incoming WebSocket messages (for ping/pong, close, etc.)
356    while let Some(msg) = receiver.next().await {
357        match msg {
358            Ok(Message::Close(_)) => break,
359            Ok(Message::Ping(_)) => {
360                // Pong is handled automatically by axum
361                tracing::trace!("Received ping");
362            }
363            Err(e) => {
364                tracing::debug!(error = %e, "WebSocket error");
365                break;
366            }
367            _ => {}
368        }
369    }
370
371    // Cleanup
372    send_task.abort();
373
374    // Remove from registry if no more subscribers
375    {
376        let mut reg = registry.write().await;
377        let key = (credential_id.clone(), chat_id.clone());
378        if let Some(tx) = reg.get(&key)
379            && tx.receiver_count() == 0
380        {
381            reg.remove(&key);
382        }
383    }
384
385    tracing::info!(
386        credential_id = %credential_id,
387        chat_id = %chat_id,
388        "WebSocket disconnected"
389    );
390}
391
392/// Send a message to a WebSocket client (called from /api/v1/send)
393pub async fn send_to_ws(
394    registry: &WsRegistry,
395    credential_id: &str,
396    chat_id: &str,
397    message: WsOutboundMessage,
398) -> bool {
399    let reg = registry.read().await;
400    let key = (credential_id.to_string(), chat_id.to_string());
401
402    if let Some(tx) = reg.get(&key) {
403        match tx.send(message) {
404            Ok(_) => {
405                tracing::debug!(
406                    credential_id = %credential_id,
407                    chat_id = %chat_id,
408                    "Message sent to WebSocket"
409                );
410                true
411            }
412            Err(_) => {
413                tracing::debug!(
414                    credential_id = %credential_id,
415                    chat_id = %chat_id,
416                    "No active WebSocket subscribers"
417                );
418                false
419            }
420        }
421    } else {
422        tracing::debug!(
423            credential_id = %credential_id,
424            chat_id = %chat_id,
425            "No WebSocket connection for this chat"
426        );
427        false
428    }
429}
430
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds the `(credential_id, chat_id)` key used by the WebSocket registry.
    fn key(credential_id: &str, chat_id: &str) -> (String, String) {
        (credential_id.to_string(), chat_id.to_string())
    }

    fn make_ws_message(text: &str, message_id: &str) -> WsOutboundMessage {
        WsOutboundMessage {
            text: text.to_string(),
            timestamp: chrono::Utc::now(),
            message_id: message_id.to_string(),
            file_urls: vec![],
        }
    }

    /// Inserts a fresh broadcast channel (capacity 100) for the given chat and
    /// returns a receiver subscribed to it.
    async fn register(
        registry: &WsRegistry,
        credential_id: &str,
        chat_id: &str,
    ) -> broadcast::Receiver<WsOutboundMessage> {
        let (tx, rx) = broadcast::channel::<WsOutboundMessage>(100);
        registry.write().await.insert(key(credential_id, chat_id), tx);
        rx
    }

    // ==================== WsRegistry Tests ====================

    #[tokio::test]
    async fn test_new_ws_registry() {
        let registry = new_ws_registry();
        assert!(registry.read().await.is_empty());
    }

    #[tokio::test]
    async fn test_ws_registry_add_channel() {
        let registry = new_ws_registry();
        let _rx = register(&registry, "cred1", "chat1").await;

        assert!(registry.read().await.contains_key(&key("cred1", "chat1")));
    }

    #[tokio::test]
    async fn test_ws_registry_remove_channel() {
        let registry = new_ws_registry();
        let _rx = register(&registry, "cred1", "chat1").await;

        registry.write().await.remove(&key("cred1", "chat1"));

        assert!(!registry.read().await.contains_key(&key("cred1", "chat1")));
    }

    // ==================== send_to_ws Tests ====================

    #[tokio::test]
    async fn test_send_to_ws_no_connection() {
        let registry = new_ws_registry();

        let message = make_ws_message("Hello", "msg_123");

        assert!(!send_to_ws(&registry, "cred1", "chat1", message).await);
    }

    #[tokio::test]
    async fn test_send_to_ws_with_subscriber() {
        let registry = new_ws_registry();
        let mut rx = register(&registry, "cred1", "chat1").await;

        let message = make_ws_message("Hello", "msg_123");
        assert!(send_to_ws(&registry, "cred1", "chat1", message).await);

        let received = rx.recv().await.unwrap();
        assert_eq!(received.text, "Hello");
        assert_eq!(received.message_id, "msg_123");
    }

    #[tokio::test]
    async fn test_send_to_ws_no_subscribers() {
        let registry = new_ws_registry();
        // Register a channel, then drop its only receiver so nobody is listening.
        let rx = register(&registry, "cred1", "chat1").await;
        drop(rx);

        let message = make_ws_message("Hello", "msg_456");

        // Should return false because no active subscribers
        assert!(!send_to_ws(&registry, "cred1", "chat1", message).await);
    }

    #[tokio::test]
    async fn test_send_to_ws_multiple_messages() {
        let registry = new_ws_registry();
        let mut rx = register(&registry, "cred1", "chat1").await;

        for i in 1..=5 {
            let message = make_ws_message(&format!("Message {}", i), &format!("msg_{}", i));
            assert!(send_to_ws(&registry, "cred1", "chat1", message).await);
        }

        // Verify all messages received in order
        for i in 1..=5 {
            let received = rx.recv().await.unwrap();
            assert_eq!(received.text, format!("Message {}", i));
            assert_eq!(received.message_id, format!("msg_{}", i));
        }
    }

    // ==================== Multiple Channels Tests ====================

    #[tokio::test]
    async fn test_ws_registry_multiple_chats() {
        let registry = new_ws_registry();
        let mut rx1 = register(&registry, "cred1", "chat1").await;
        let mut rx2 = register(&registry, "cred1", "chat2").await;

        let msg1 = make_ws_message("Message for chat1", "msg_1");
        assert!(send_to_ws(&registry, "cred1", "chat1", msg1).await);

        let msg2 = make_ws_message("Message for chat2", "msg_2");
        assert!(send_to_ws(&registry, "cred1", "chat2", msg2).await);

        // Verify correct routing
        assert_eq!(rx1.recv().await.unwrap().text, "Message for chat1");
        assert_eq!(rx2.recv().await.unwrap().text, "Message for chat2");
    }

    #[tokio::test]
    async fn test_ws_registry_different_credentials() {
        let registry = new_ws_registry();
        let mut rx_cred1 = register(&registry, "cred1", "chat").await;
        let mut rx_cred2 = register(&registry, "cred2", "chat").await;

        let msg = make_ws_message("For cred1", "msg_1");
        assert!(send_to_ws(&registry, "cred1", "chat", msg).await);

        // Only cred1 should receive
        assert_eq!(rx_cred1.recv().await.unwrap().text, "For cred1");
        assert!(rx_cred2.try_recv().is_err());
    }

    #[tokio::test]
    async fn test_ws_registry_broadcast_to_multiple_subscribers() {
        let registry = new_ws_registry();
        let mut rx1 = register(&registry, "cred1", "chat1").await;
        let mut rx2 = registry
            .read()
            .await
            .get(&key("cred1", "chat1"))
            .expect("channel was just registered")
            .subscribe();

        let message = make_ws_message("Broadcast message", "msg_broadcast");
        assert!(send_to_ws(&registry, "cred1", "chat1", message).await);

        // Both subscribers should receive the message
        assert_eq!(rx1.recv().await.unwrap().text, "Broadcast message");
        assert_eq!(rx2.recv().await.unwrap().text, "Broadcast message");
    }

    // ==================== ChatRequest Tests ====================

    #[test]
    fn test_chat_request_parse() {
        let json = r#"{
            "chat_id": "12345",
            "text": "Hello, world!",
            "from": {
                "id": "user_1",
                "display_name": "Test User"
            }
        }"#;

        let req: ChatRequest = serde_json::from_str(json).unwrap();
        assert_eq!(req.chat_id, "12345");
        assert_eq!(req.text, "Hello, world!");
        assert_eq!(req.from.id, "user_1");
        assert_eq!(req.from.display_name, Some("Test User".to_string()));
    }

    #[test]
    fn test_chat_request_minimal() {
        let json = r#"{
            "chat_id": "12345",
            "text": "Hello",
            "from": {"id": "user_1"}
        }"#;

        let req: ChatRequest = serde_json::from_str(json).unwrap();
        assert_eq!(req.chat_id, "12345");
        assert_eq!(req.from.id, "user_1");
        assert!(req.from.display_name.is_none());
    }

    #[test]
    fn test_chat_request_missing_from_is_rejected() {
        let json = r#"{"chat_id": "123", "text": "hello"}"#;
        assert!(serde_json::from_str::<ChatRequest>(json).is_err());
    }

    #[test]
    fn test_chat_request_with_files() {
        let json = r#"{
            "chat_id": "123",
            "text": "see attached",
            "from": {"id": "u1"},
            "files": [
                {"url": "https://example.com/img.jpg", "filename": "img.jpg", "mime_type": "image/jpeg"}
            ]
        }"#;
        let req: ChatRequest = serde_json::from_str(json).unwrap();
        assert_eq!(req.files.len(), 1);
        assert_eq!(req.files[0].url, "https://example.com/img.jpg");
        assert_eq!(req.files[0].filename, "img.jpg");
        assert!(req.files[0].auth_header.is_none());
    }

    #[test]
    fn test_chat_request_with_file_auth_header() {
        let json = r#"{
            "chat_id": "123",
            "text": "private file",
            "from": {"id": "u1"},
            "files": [
                {"url": "https://example.com/a.pdf", "filename": "a.pdf", "mime_type": "application/pdf", "auth_header": "Bearer abc"}
            ]
        }"#;
        let req: ChatRequest = serde_json::from_str(json).unwrap();
        assert_eq!(req.files[0].mime_type, "application/pdf");
        assert_eq!(req.files[0].auth_header.as_deref(), Some("Bearer abc"));
    }

    #[test]
    fn test_chat_request_no_files() {
        let json = r#"{"chat_id": "123", "text": "hello", "from": {"id": "u1"}}"#;
        let req: ChatRequest = serde_json::from_str(json).unwrap();
        assert!(req.files.is_empty());
    }

    // ==================== ChatUser Tests ====================

    #[test]
    fn test_chat_user_parse() {
        let json = r#"{"id": "user_123", "display_name": "John Doe"}"#;
        let user: ChatUser = serde_json::from_str(json).unwrap();
        assert_eq!(user.id, "user_123");
        assert_eq!(user.display_name, Some("John Doe".to_string()));
    }

    #[test]
    fn test_chat_user_minimal() {
        let json = r#"{"id": "user_123"}"#;
        let user: ChatUser = serde_json::from_str(json).unwrap();
        assert_eq!(user.id, "user_123");
        assert!(user.display_name.is_none());
    }

    // ==================== ChatResponse Tests ====================

    #[test]
    fn test_chat_response_serialize() {
        let response = ChatResponse {
            message_id: "msg_123".to_string(),
            timestamp: chrono::Utc::now(),
        };

        let json = serde_json::to_string(&response).unwrap();
        assert!(json.contains("\"message_id\":\"msg_123\""));
        assert!(json.contains("\"timestamp\""));
    }
}