1use axum::{
2 Json,
3 extract::{
4 Path, State,
5 ws::{Message, WebSocket, WebSocketUpgrade},
6 },
7 http::header,
8 response::IntoResponse,
9};
10use futures_util::{SinkExt, StreamExt};
11use serde::{Deserialize, Serialize};
12use std::collections::HashMap;
13use std::sync::Arc;
14use tokio::sync::{RwLock, broadcast};
15
16use crate::error::AppError;
17use crate::guardrail::GuardrailVerdict;
18use crate::message::{InboundMessage, MessageSource, UserInfo, WsOutboundMessage};
19use crate::server::AppState;
20
21pub type WsRegistry = Arc<RwLock<HashMap<(String, String), broadcast::Sender<WsOutboundMessage>>>>;
24
25pub fn new_ws_registry() -> WsRegistry {
27 Arc::new(RwLock::new(HashMap::new()))
28}
29
30#[derive(Debug, Deserialize)]
32pub struct ChatRequest {
33 pub chat_id: String,
34 pub text: String,
35 pub from: ChatUser,
36 #[serde(default)]
37 pub files: Vec<InboundFileRef>,
38}
39
40#[derive(Debug, Deserialize)]
41pub struct ChatUser {
42 pub id: String,
43 #[serde(default)]
44 pub display_name: Option<String>,
45}
46
47#[derive(Debug, Deserialize)]
48pub struct InboundFileRef {
49 pub url: String,
50 pub filename: String,
51 pub mime_type: String,
52 #[serde(default)]
53 pub auth_header: Option<String>,
54}
55
56#[derive(Debug, Serialize)]
58pub struct ChatResponse {
59 pub message_id: String,
60 pub timestamp: chrono::DateTime<chrono::Utc>,
61}
62
63pub async fn chat_inbound(
66 State(state): State<Arc<AppState>>,
67 Path(credential_id): Path<String>,
68 headers: axum::http::HeaderMap,
69 Json(payload): Json<ChatRequest>,
70) -> Result<impl IntoResponse, AppError> {
71 let config = state.config.read().await;
72
73 let credential = config
75 .credentials
76 .get(&credential_id)
77 .ok_or_else(|| AppError::CredentialNotFound(credential_id.clone()))?;
78
79 if credential.adapter != "generic" {
81 return Err(AppError::Internal(format!(
82 "Credential {} is not a generic adapter credential",
83 credential_id
84 )));
85 }
86
87 let expected_token = &credential.token;
89 let auth_header = headers
90 .get(header::AUTHORIZATION)
91 .and_then(|v| v.to_str().ok());
92
93 match auth_header {
94 Some(auth) if auth.starts_with("Bearer ") => {
95 let token = &auth[7..];
96 if token != expected_token {
97 return Err(AppError::Unauthorized);
98 }
99 }
100 _ => return Err(AppError::Unauthorized),
101 }
102
103 if !credential.active {
105 return Err(AppError::CredentialInactive(credential_id.clone()));
106 }
107
108 let route = credential.route.clone();
109
110 let backend_name = crate::backend::resolve_backend_name(credential, &config.gateway)
111 .ok_or_else(|| {
112 AppError::Internal("No backend configured for this credential".to_string())
113 })?;
114 let backend_cfg = config.backends.get(&backend_name).ok_or_else(|| {
115 AppError::Internal(format!("Backend '{}' not found in config", backend_name))
116 })?;
117 let gateway_ctx = crate::backend::GatewayContext {
118 gateway_url: format!("http://{}", config.gateway.listen),
119 send_token: config.auth.send_token.clone(),
120 };
121 let adapter = match crate::backend::create_adapter(
122 backend_cfg,
123 Some(&gateway_ctx),
124 credential.config.as_ref().or(backend_cfg.config.as_ref()),
125 ) {
126 Ok(a) => a,
127 Err(e) => {
128 return Err(AppError::Internal(format!(
129 "Failed to create backend adapter: {}",
130 e
131 )));
132 }
133 };
134 drop(config);
135
136 let message_id = format!("generic_{}", uuid::Uuid::new_v4());
138 let timestamp = chrono::Utc::now();
139
140 let mut attachments = vec![];
142
143 if state.file_cache.is_none() && !payload.files.is_empty() {
145 tracing::warn!("Files received but file cache not configured, skipping attachments");
146 }
147
148 for file_ref in &payload.files {
149 let Some(ref file_cache) = state.file_cache else {
150 continue;
151 };
152
153 match file_cache
154 .download_and_cache(
155 &file_ref.url,
156 file_ref.auth_header.as_deref(),
157 &file_ref.filename,
158 &file_ref.mime_type,
159 )
160 .await
161 {
162 Ok(cached) => {
163 attachments.push(crate::message::Attachment {
164 filename: cached.filename.clone(),
165 mime_type: cached.mime_type.clone(),
166 size_bytes: cached.size_bytes,
167 download_url: file_cache.get_download_url(&cached.file_id),
168 });
169 tracing::info!(
170 file_id = %cached.file_id,
171 filename = %cached.filename,
172 "Generic inbound file cached"
173 );
174 }
175 Err(e) => {
176 tracing::warn!(
177 url = %file_ref.url,
178 error = %e,
179 "Failed to cache generic inbound file attachment"
180 );
181 }
182 }
183 }
184
185 let inbound = InboundMessage {
187 route,
188 credential_id: credential_id.clone(),
189 source: MessageSource {
190 protocol: "generic".to_string(),
191 chat_id: payload.chat_id.clone(),
192 message_id: message_id.clone(),
193 reply_to_message_id: None,
194 from: UserInfo {
195 id: payload.from.id,
196 username: None,
197 display_name: payload.from.display_name,
198 },
199 },
200 text: payload.text,
201 attachments,
202 timestamp,
203 extra_data: None,
204 };
205
206 let verdict = {
207 let engine = state.guardrail_engine.read().await;
208 engine.evaluate_inbound(&inbound)
209 };
210 match verdict {
211 GuardrailVerdict::Block { reject_message, .. } => {
212 return Err(AppError::Forbidden(reject_message));
213 }
214 GuardrailVerdict::Allow => {}
215 }
216
217 let health_state = state.health_monitor.get_state().await;
219 if health_state == crate::health::HealthState::Down {
220 state.health_monitor.buffer_message(inbound).await;
221 tracing::info!(
222 credential_id = %credential_id,
223 message_id = %message_id,
224 "Message buffered (target server down)"
225 );
226 } else {
227 let message_id_for_task = message_id.clone();
229 let credential_id_for_task = credential_id.clone();
230
231 tokio::spawn(async move {
233 match adapter.send_message(&inbound).await {
234 Ok(()) => {
235 tracing::debug!(
236 credential_id = %credential_id_for_task,
237 message_id = %message_id_for_task,
238 "Message forwarded to backend"
239 );
240 }
241 Err(e) => {
242 tracing::error!(
243 credential_id = %credential_id_for_task,
244 error = %e,
245 "Failed to forward message to backend"
246 );
247 }
248 }
249 });
250 }
251
252 Ok((
254 axum::http::StatusCode::ACCEPTED,
255 Json(ChatResponse {
256 message_id,
257 timestamp,
258 }),
259 ))
260}
261
262pub async fn ws_handler(
265 State(state): State<Arc<AppState>>,
266 Path((credential_id, chat_id)): Path<(String, String)>,
267 headers: axum::http::HeaderMap,
268 ws: WebSocketUpgrade,
269) -> Result<impl IntoResponse, AppError> {
270 let config = state.config.read().await;
271
272 let credential = config
274 .credentials
275 .get(&credential_id)
276 .ok_or_else(|| AppError::CredentialNotFound(credential_id.clone()))?;
277
278 if credential.adapter != "generic" {
280 return Err(AppError::Internal(format!(
281 "Credential {} is not a generic adapter credential",
282 credential_id
283 )));
284 }
285
286 let expected_token = &credential.token;
288 let auth_header = headers
289 .get(header::AUTHORIZATION)
290 .and_then(|v| v.to_str().ok());
291
292 match auth_header {
293 Some(auth) if auth.starts_with("Bearer ") => {
294 let token = &auth[7..];
295 if token != expected_token {
296 return Err(AppError::Unauthorized);
297 }
298 }
299 _ => return Err(AppError::Unauthorized),
300 }
301
302 if !credential.active {
303 return Err(AppError::CredentialInactive(credential_id.clone()));
304 }
305
306 drop(config);
307
308 let ws_registry = state.ws_registry.clone();
309 let cred_id = credential_id.clone();
310 let c_id = chat_id.clone();
311
312 Ok(ws.on_upgrade(move |socket| handle_ws(socket, ws_registry, cred_id, c_id)))
313}
314
315async fn handle_ws(
316 socket: WebSocket,
317 registry: WsRegistry,
318 credential_id: String,
319 chat_id: String,
320) {
321 let (sender, mut receiver) = socket.split();
322 let sender = Arc::new(tokio::sync::Mutex::new(sender));
323
324 let mut rx = {
326 let mut reg = registry.write().await;
327 let key = (credential_id.clone(), chat_id.clone());
328
329 let tx = reg.entry(key).or_insert_with(|| {
330 let (tx, _) = broadcast::channel(100);
331 tx
332 });
333
334 tx.subscribe()
335 };
336
337 tracing::info!(
338 credential_id = %credential_id,
339 chat_id = %chat_id,
340 "WebSocket connected"
341 );
342
343 let sender_clone = sender.clone();
345 let send_task = tokio::spawn(async move {
346 while let Ok(msg) = rx.recv().await {
347 let json = serde_json::to_string(&msg).unwrap();
348 let mut s = sender_clone.lock().await;
349 if s.send(Message::Text(json.into())).await.is_err() {
350 break;
351 }
352 }
353 });
354
355 while let Some(msg) = receiver.next().await {
357 match msg {
358 Ok(Message::Close(_)) => break,
359 Ok(Message::Ping(_)) => {
360 tracing::trace!("Received ping");
362 }
363 Err(e) => {
364 tracing::debug!(error = %e, "WebSocket error");
365 break;
366 }
367 _ => {}
368 }
369 }
370
371 send_task.abort();
373
374 {
376 let mut reg = registry.write().await;
377 let key = (credential_id.clone(), chat_id.clone());
378 if let Some(tx) = reg.get(&key)
379 && tx.receiver_count() == 0
380 {
381 reg.remove(&key);
382 }
383 }
384
385 tracing::info!(
386 credential_id = %credential_id,
387 chat_id = %chat_id,
388 "WebSocket disconnected"
389 );
390}
391
392pub async fn send_to_ws(
394 registry: &WsRegistry,
395 credential_id: &str,
396 chat_id: &str,
397 message: WsOutboundMessage,
398) -> bool {
399 let reg = registry.read().await;
400 let key = (credential_id.to_string(), chat_id.to_string());
401
402 if let Some(tx) = reg.get(&key) {
403 match tx.send(message) {
404 Ok(_) => {
405 tracing::debug!(
406 credential_id = %credential_id,
407 chat_id = %chat_id,
408 "Message sent to WebSocket"
409 );
410 true
411 }
412 Err(_) => {
413 tracing::debug!(
414 credential_id = %credential_id,
415 chat_id = %chat_id,
416 "No active WebSocket subscribers"
417 );
418 false
419 }
420 }
421 } else {
422 tracing::debug!(
423 credential_id = %credential_id,
424 chat_id = %chat_id,
425 "No WebSocket connection for this chat"
426 );
427 false
428 }
429}
430
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an outbound message with the given text and id, the current
    /// time as timestamp, and no file URLs.
    fn make_ws_message(text: &str, message_id: &str) -> WsOutboundMessage {
        WsOutboundMessage {
            text: text.to_string(),
            timestamp: chrono::Utc::now(),
            message_id: message_id.to_string(),
            file_urls: vec![],
        }
    }

    /// Builds the registry key for a credential/chat pair.
    fn chat_key(credential_id: &str, chat_id: &str) -> (String, String) {
        (credential_id.to_string(), chat_id.to_string())
    }

    /// Registers a fresh broadcast channel for the pair and returns its
    /// initial receiver. Dropping the receiver leaves a channel with no
    /// subscribers in the registry.
    async fn register_channel(
        registry: &WsRegistry,
        credential_id: &str,
        chat_id: &str,
    ) -> broadcast::Receiver<WsOutboundMessage> {
        let (tx, rx) = broadcast::channel(100);
        registry
            .write()
            .await
            .insert(chat_key(credential_id, chat_id), tx);
        rx
    }

    // --- Registry basics -------------------------------------------------

    #[tokio::test]
    async fn test_new_ws_registry() {
        let registry = new_ws_registry();
        assert!(registry.read().await.is_empty());
    }

    #[tokio::test]
    async fn test_ws_registry_add_channel() {
        let registry = new_ws_registry();

        drop(register_channel(&registry, "cred1", "chat1").await);

        let reg = registry.read().await;
        assert!(reg.contains_key(&chat_key("cred1", "chat1")));
    }

    #[tokio::test]
    async fn test_ws_registry_remove_channel() {
        let registry = new_ws_registry();

        drop(register_channel(&registry, "cred1", "chat1").await);

        {
            let mut reg = registry.write().await;
            reg.remove(&chat_key("cred1", "chat1"));
        }

        let reg = registry.read().await;
        assert!(!reg.contains_key(&chat_key("cred1", "chat1")));
    }

    // --- send_to_ws ------------------------------------------------------

    #[tokio::test]
    async fn test_send_to_ws_no_connection() {
        let registry = new_ws_registry();

        let message = make_ws_message("Hello", "msg_123");

        let result = send_to_ws(&registry, "cred1", "chat1", message).await;
        assert!(!result);
    }

    #[tokio::test]
    async fn test_send_to_ws_with_subscriber() {
        let registry = new_ws_registry();
        let mut rx = register_channel(&registry, "cred1", "chat1").await;

        let message = make_ws_message("Hello", "msg_123");

        let result = send_to_ws(&registry, "cred1", "chat1", message).await;
        assert!(result);

        let received = rx.recv().await.unwrap();
        assert_eq!(received.text, "Hello");
        assert_eq!(received.message_id, "msg_123");
    }

    #[tokio::test]
    async fn test_send_to_ws_no_subscribers() {
        let registry = new_ws_registry();

        // The channel stays registered, but its only receiver is gone.
        drop(register_channel(&registry, "cred1", "chat1").await);

        let message = make_ws_message("Hello", "msg_456");

        let result = send_to_ws(&registry, "cred1", "chat1", message).await;
        assert!(!result);
    }

    // --- Request / response (de)serialization ----------------------------

    #[test]
    fn test_chat_request_parse() {
        let json = r#"{
            "chat_id": "12345",
            "text": "Hello, world!",
            "from": {
                "id": "user_1",
                "display_name": "Test User"
            }
        }"#;

        let req: ChatRequest = serde_json::from_str(json).unwrap();
        assert_eq!(req.chat_id, "12345");
        assert_eq!(req.text, "Hello, world!");
        assert_eq!(req.from.id, "user_1");
        assert_eq!(req.from.display_name, Some("Test User".to_string()));
    }

    #[test]
    fn test_chat_request_minimal() {
        let json = r#"{
            "chat_id": "12345",
            "text": "Hello",
            "from": {"id": "user_1"}
        }"#;

        let req: ChatRequest = serde_json::from_str(json).unwrap();
        assert_eq!(req.chat_id, "12345");
        assert_eq!(req.from.id, "user_1");
        assert!(req.from.display_name.is_none());
    }

    #[test]
    fn test_chat_user_parse() {
        let json = r#"{"id": "user_123", "display_name": "John Doe"}"#;
        let user: ChatUser = serde_json::from_str(json).unwrap();
        assert_eq!(user.id, "user_123");
        assert_eq!(user.display_name, Some("John Doe".to_string()));
    }

    #[test]
    fn test_chat_user_minimal() {
        let json = r#"{"id": "user_123"}"#;
        let user: ChatUser = serde_json::from_str(json).unwrap();
        assert_eq!(user.id, "user_123");
        assert!(user.display_name.is_none());
    }

    #[test]
    fn test_chat_response_serialize() {
        let response = ChatResponse {
            message_id: "msg_123".to_string(),
            timestamp: chrono::Utc::now(),
        };

        let json = serde_json::to_string(&response).unwrap();
        assert!(json.contains("\"message_id\":\"msg_123\""));
        assert!(json.contains("\"timestamp\""));
    }

    // --- Routing between chats, credentials and subscribers ---------------

    #[tokio::test]
    async fn test_ws_registry_multiple_chats() {
        let registry = new_ws_registry();

        let mut rx1 = register_channel(&registry, "cred1", "chat1").await;
        let mut rx2 = register_channel(&registry, "cred1", "chat2").await;

        let msg1 = make_ws_message("Message for chat1", "msg_1");
        assert!(send_to_ws(&registry, "cred1", "chat1", msg1).await);

        let msg2 = make_ws_message("Message for chat2", "msg_2");
        assert!(send_to_ws(&registry, "cred1", "chat2", msg2).await);

        let received1 = rx1.recv().await.unwrap();
        assert_eq!(received1.text, "Message for chat1");

        let received2 = rx2.recv().await.unwrap();
        assert_eq!(received2.text, "Message for chat2");
    }

    #[tokio::test]
    async fn test_ws_registry_different_credentials() {
        let registry = new_ws_registry();

        let mut rx_cred1 = register_channel(&registry, "cred1", "chat").await;
        let mut rx_cred2 = register_channel(&registry, "cred2", "chat").await;

        let msg = make_ws_message("For cred1", "msg_1");
        assert!(send_to_ws(&registry, "cred1", "chat", msg).await);

        let received = rx_cred1.recv().await.unwrap();
        assert_eq!(received.text, "For cred1");

        // The same chat id under another credential must not see the message.
        assert!(rx_cred2.try_recv().is_err());
    }

    #[tokio::test]
    async fn test_send_to_ws_multiple_messages() {
        let registry = new_ws_registry();
        let mut rx = register_channel(&registry, "cred1", "chat1").await;

        for i in 1..=5 {
            let message = make_ws_message(&format!("Message {}", i), &format!("msg_{}", i));
            let result = send_to_ws(&registry, "cred1", "chat1", message).await;
            assert!(result);
        }

        // Messages arrive in the order they were sent.
        for i in 1..=5 {
            let received = rx.recv().await.unwrap();
            assert_eq!(received.text, format!("Message {}", i));
            assert_eq!(received.message_id, format!("msg_{}", i));
        }
    }

    #[tokio::test]
    async fn test_ws_registry_broadcast_to_multiple_subscribers() {
        let registry = new_ws_registry();

        let (mut rx1, mut rx2);
        {
            let mut reg = registry.write().await;
            let (tx, r1) = broadcast::channel::<WsOutboundMessage>(100);
            rx1 = r1;
            rx2 = tx.subscribe();
            reg.insert(chat_key("cred1", "chat1"), tx);
        }

        let message = make_ws_message("Broadcast message", "msg_broadcast");
        let result = send_to_ws(&registry, "cred1", "chat1", message).await;
        assert!(result);

        let received1 = rx1.recv().await.unwrap();
        let received2 = rx2.recv().await.unwrap();

        assert_eq!(received1.text, "Broadcast message");
        assert_eq!(received2.text, "Broadcast message");
    }

    // --- File references ---------------------------------------------------

    #[test]
    fn test_chat_request_with_files() {
        let json = r#"{
            "chat_id": "123",
            "text": "see attached",
            "from": {"id": "u1"},
            "files": [
                {"url": "https://example.com/img.jpg", "filename": "img.jpg", "mime_type": "image/jpeg"}
            ]
        }"#;
        let req: ChatRequest = serde_json::from_str(json).unwrap();
        assert_eq!(req.files.len(), 1);
        assert_eq!(req.files[0].url, "https://example.com/img.jpg");
        assert_eq!(req.files[0].filename, "img.jpg");
        assert!(req.files[0].auth_header.is_none());
    }

    #[test]
    fn test_chat_request_no_files() {
        let json = r#"{"chat_id": "123", "text": "hello", "from": {"id": "u1"}}"#;
        let req: ChatRequest = serde_json::from_str(json).unwrap();
        assert!(req.files.is_empty());
    }
}