Skip to main content

msg_gateway/
server.rs

1use axum::{
2    Json, Router,
3    body::Body,
4    extract::State,
5    http::{Request, header},
6    middleware::{self, Next},
7    response::{IntoResponse, Response},
8    routing::{get, post},
9};
10use serde_json::json;
11use std::sync::Arc;
12use tokio::sync::RwLock;
13use tower_http::trace::TraceLayer;
14
15use crate::adapter::AdapterInstanceManager;
16use crate::admin;
17use crate::backend::ExternalBackendManager;
18use crate::config::Config;
19use crate::error::AppError;
20use crate::files::FileCache;
21use crate::generic::{self, WsRegistry};
22use crate::guardrail::{GuardrailEngine, GuardrailVerdict, load_rules_from_dir};
23use crate::health::HealthMonitor;
24use crate::manager::CredentialManager;
25use crate::message::WsOutboundMessage;
26
27pub struct AppState {
28    pub config: RwLock<Config>,
29    pub ws_registry: WsRegistry,
30    pub manager: Arc<CredentialManager>,
31    pub adapter_manager: Arc<AdapterInstanceManager>,
32    pub backend_manager: Arc<ExternalBackendManager>,
33    pub skip_reload_until: RwLock<Option<std::time::Instant>>,
34    pub health_monitor: HealthMonitor,
35    pub file_cache: Option<Arc<FileCache>>,
36    pub guardrail_engine: RwLock<GuardrailEngine>,
37}
38
39use std::future::Future;
40use std::pin::Pin;
41
42/// Create the server and return the state + a future to run
43pub async fn create_server(
44    config: Config,
45    manager: Arc<CredentialManager>,
46    adapter_manager: Arc<AdapterInstanceManager>,
47    backend_manager: Arc<ExternalBackendManager>,
48) -> anyhow::Result<(
49    Arc<AppState>,
50    Pin<Box<dyn Future<Output = anyhow::Result<()>> + Send>>,
51)> {
52    let listen_addr = config.gateway.listen.clone();
53    let gateway_url = format!("http://{}", listen_addr);
54
55    // Default max buffer size: 1000 messages
56    let max_buffer_size = 1000;
57
58    // Initialize file cache if configured
59    let file_cache = if let Some(ref cache_config) = config.gateway.file_cache {
60        match FileCache::new(cache_config.clone(), &gateway_url).await {
61            Ok(cache) => {
62                tracing::info!(
63                    directory = %cache_config.directory,
64                    "File cache initialized"
65                );
66                Some(Arc::new(cache))
67            }
68            Err(e) => {
69                tracing::error!(error = %e, "Failed to initialize file cache");
70                None
71            }
72        }
73    } else {
74        None
75    };
76
77    let guardrail_rules = if let Some(ref dir) = config.gateway.guardrails_dir {
78        load_rules_from_dir(std::path::Path::new(dir))
79    } else {
80        vec![]
81    };
82    let guardrail_engine = GuardrailEngine::from_rules(guardrail_rules);
83
84    let state = Arc::new(AppState {
85        config: RwLock::new(config),
86        ws_registry: generic::new_ws_registry(),
87        manager,
88        adapter_manager,
89        backend_manager,
90        skip_reload_until: RwLock::new(None),
91        health_monitor: HealthMonitor::new(max_buffer_size),
92        file_cache,
93        guardrail_engine: RwLock::new(guardrail_engine),
94    });
95
96    let app = Router::new()
97        // Public health endpoint (no auth)
98        .route("/health", get(health))
99        // Send endpoint (requires send_token)
100        .route("/api/v1/send", post(send_message))
101        // Adapter inbound endpoint (from external adapters)
102        .route("/api/v1/adapter/inbound", post(adapter_inbound))
103        // File upload endpoint (requires send_token)
104        .route("/api/v1/files", post(upload_file))
105        // File serving endpoint (no auth — file IDs are unguessable UUIDs)
106        .route("/files/{file_id}", get(serve_file))
107        // Generic protocol endpoints
108        .route("/api/v1/chat/{credential_id}", post(generic::chat_inbound))
109        .route(
110            "/ws/chat/{credential_id}/{chat_id}",
111            get(generic::ws_handler),
112        )
113        // Admin routes (requires admin_token)
114        .nest("/admin", admin_routes(state.clone()))
115        .layer(TraceLayer::new_for_http())
116        .with_state(state.clone());
117
118    let listener = tokio::net::TcpListener::bind(&listen_addr).await?;
119    tracing::info!("Listening on {}", listen_addr);
120
121    let server_future = Box::pin(async move {
122        axum::serve(listener, app).await?;
123        Ok(())
124    });
125
126    Ok((state, server_future))
127}
128
129fn admin_routes(state: Arc<AppState>) -> Router<Arc<AppState>> {
130    use axum::routing::patch;
131
132    Router::new()
133        .route("/health", get(admin_health))
134        .route(
135            "/credentials",
136            get(list_credentials).post(admin::create_credential),
137        )
138        .route(
139            "/credentials/{id}",
140            get(admin::get_credential)
141                .put(admin::update_credential)
142                .delete(admin::delete_credential),
143        )
144        .route(
145            "/credentials/{id}/activate",
146            patch(admin::activate_credential),
147        )
148        .route(
149            "/credentials/{id}/deactivate",
150            patch(admin::deactivate_credential),
151        )
152        .layer(middleware::from_fn_with_state(state, admin_auth_middleware))
153}
154
155// Middleware for admin authentication
156async fn admin_auth_middleware(
157    State(state): State<Arc<AppState>>,
158    request: Request<Body>,
159    next: Next,
160) -> Result<Response, AppError> {
161    let config = state.config.read().await;
162    let expected_token = &config.gateway.admin_token;
163
164    let auth_header = request
165        .headers()
166        .get(header::AUTHORIZATION)
167        .and_then(|v| v.to_str().ok());
168
169    match auth_header {
170        Some(auth) if auth.starts_with("Bearer ") => {
171            let token = &auth[7..];
172            if token == expected_token {
173                drop(config);
174                Ok(next.run(request).await)
175            } else {
176                Err(AppError::Unauthorized)
177            }
178        }
179        _ => Err(AppError::Unauthorized),
180    }
181}
182
183// Public health check
184async fn health() -> impl IntoResponse {
185    Json(json!({
186        "status": "ok"
187    }))
188}
189
190// Admin health check with more details
191async fn admin_health(State(state): State<Arc<AppState>>) -> impl IntoResponse {
192    let config = state.config.read().await;
193    let credential_count = config.credentials.len();
194    let active_count = config.credentials.values().filter(|c| c.active).count();
195    drop(config);
196
197    // Get instance statuses
198    let instance_statuses = state.manager.registry.get_all_status().await;
199    let running_count = instance_statuses
200        .values()
201        .filter(|(_, status)| *status == crate::manager::InstanceStatus::Running)
202        .count();
203
204    // Get adapter health statuses
205    let adapter_health = state.adapter_manager.get_all_health().await;
206    let adapters: Vec<_> = adapter_health
207        .iter()
208        .map(|(cred_id, (adapter_name, health, failures))| {
209            json!({
210                "credential_id": cred_id,
211                "adapter": adapter_name,
212                "health": format!("{:?}", health),
213                "consecutive_failures": failures
214            })
215        })
216        .collect();
217
218    // Get health monitor status
219    let health_state = state.health_monitor.get_state().await;
220    let buffer_size = state.health_monitor.buffer_size().await;
221    let last_healthy = state
222        .health_monitor
223        .last_healthy_ago()
224        .await
225        .map(|d| format!("{:.1}s ago", d.as_secs_f64()));
226
227    Json(json!({
228        "status": "ok",
229        "credentials": {
230            "total": credential_count,
231            "active": active_count,
232            "running_tasks": running_count
233        },
234        "adapters": adapters,
235        "target_server": {
236            "state": health_state.to_string(),
237            "last_healthy": last_healthy,
238            "buffered_messages": buffer_size
239        }
240    }))
241}
242
243// List all credentials (tokens redacted)
244async fn list_credentials(State(state): State<Arc<AppState>>) -> impl IntoResponse {
245    let config = state.config.read().await;
246
247    let credentials: Vec<_> = config
248        .credentials
249        .iter()
250        .map(|(id, cred)| {
251            json!({
252                "id": id,
253                "adapter": cred.adapter,
254                "active": cred.active,
255                "emergency": cred.emergency,
256                "route": cred.route
257            })
258        })
259        .collect();
260
261    Json(json!({
262        "credentials": credentials
263    }))
264}
265
266/// File attachment in send request
267#[derive(Debug, serde::Deserialize)]
268struct SendFileAttachment {
269    /// URL to download the file from
270    url: String,
271    /// Original filename
272    filename: String,
273    /// MIME type
274    mime_type: String,
275    /// Optional auth header for downloading
276    #[serde(default)]
277    auth_header: Option<String>,
278}
279
280// Send message endpoint (Pipelit → Gateway → Protocol)
281async fn send_message(
282    State(state): State<Arc<AppState>>,
283    headers: axum::http::HeaderMap,
284    Json(payload): Json<serde_json::Value>,
285) -> Result<impl IntoResponse, AppError> {
286    // Verify send token
287    let config = state.config.read().await;
288    verify_send_token(&headers, &config.auth.send_token)?;
289
290    // Extract fields from payload
291    let credential_id = payload
292        .get("credential_id")
293        .and_then(|v| v.as_str())
294        .ok_or_else(|| AppError::Internal("Missing credential_id".to_string()))?;
295
296    let chat_id = payload
297        .get("chat_id")
298        .and_then(|v| v.as_str())
299        .ok_or_else(|| AppError::Internal("Missing chat_id".to_string()))?;
300
301    let text = payload.get("text").and_then(|v| v.as_str()).unwrap_or(""); // Text is optional when sending file
302
303    // Parse optional extra_data (pass through transparently)
304    let extra_data: Option<serde_json::Value> = payload.get("extra_data").cloned();
305
306    // Parse file_ids (v0.2+) — references to files already uploaded via POST /api/v1/files
307    let file_ids: Vec<String> = payload
308        .get("file_ids")
309        .and_then(|v| v.as_array())
310        .map(|arr| {
311            arr.iter()
312                .filter_map(|v| v.as_str().map(String::from))
313                .collect()
314        })
315        .unwrap_or_default();
316
317    // Parse optional file attachment (v0.1 compat — single file download)
318    let file_attachment: Option<SendFileAttachment> = payload
319        .get("file")
320        .and_then(|v| serde_json::from_value(v.clone()).ok());
321
322    // Check credential exists and is active
323    let credential = config
324        .credentials
325        .get(credential_id)
326        .ok_or_else(|| AppError::CredentialNotFound(credential_id.to_string()))?;
327
328    if !credential.active {
329        return Err(AppError::CredentialInactive(credential_id.to_string()));
330    }
331
332    let adapter = credential.adapter.clone();
333    drop(config);
334
335    let message_id = format!("{}_{}", adapter, uuid::Uuid::new_v4());
336    let timestamp = chrono::Utc::now();
337
338    // Resolve file_ids to local file paths
339    let mut file_paths: Vec<String> = Vec::new();
340
341    for file_id in &file_ids {
342        if let Some(ref file_cache) = state.file_cache {
343            if let Some(path) = file_cache.get_file_path(file_id).await {
344                file_paths.push(path.to_string_lossy().to_string());
345            } else {
346                return Err(AppError::NotFound(format!("File not found: {}", file_id)));
347            }
348        } else {
349            return Err(AppError::Internal("File cache not configured".to_string()));
350        }
351    }
352
353    // Handle legacy file attachment (v0.1 compat): download and cache
354    let (file_path, legacy_file_id): (Option<String>, Option<String>) =
355        if let Some(file) = file_attachment {
356            if let Some(ref file_cache) = state.file_cache {
357                match file_cache
358                    .download_and_cache(
359                        &file.url,
360                        file.auth_header.as_deref(),
361                        &file.filename,
362                        &file.mime_type,
363                    )
364                    .await
365                {
366                    Ok(cached) => {
367                        tracing::info!(
368                            file_id = %cached.file_id,
369                            filename = %file.filename,
370                            "Outbound file cached"
371                        );
372                        let path = cached.path.to_string_lossy().to_string();
373                        let id = cached.file_id.clone();
374                        (Some(path), Some(id))
375                    }
376                    Err(e) => {
377                        tracing::error!(
378                            error = %e,
379                            filename = %file.filename,
380                            "Failed to cache outbound file"
381                        );
382                        return Err(AppError::Internal(format!(
383                            "Failed to download file: {}",
384                            e
385                        )));
386                    }
387                }
388            } else {
389                tracing::warn!("File attachment in send request but file cache not configured");
390                return Err(AppError::Internal("File cache not configured".to_string()));
391            }
392        } else {
393            (None, None)
394        };
395
396    // Route to appropriate adapter
397    if adapter == "generic" {
398        let mut file_urls: Vec<String> = vec![];
399        if let Some(ref fc) = state.file_cache {
400            for fid in &file_ids {
401                file_urls.push(fc.get_download_url(fid));
402            }
403            if let Some(ref fid) = legacy_file_id {
404                file_urls.push(fc.get_download_url(fid));
405            }
406        }
407
408        // Built-in generic adapter: send via WebSocket
409        let ws_msg = WsOutboundMessage {
410            text: text.to_string(),
411            timestamp,
412            message_id: message_id.clone(),
413            file_urls,
414        };
415
416        let sent = generic::send_to_ws(&state.ws_registry, credential_id, chat_id, ws_msg).await;
417
418        if sent {
419            tracing::info!(
420                credential_id = credential_id,
421                chat_id = chat_id,
422                "Message sent via WebSocket"
423            );
424        } else {
425            tracing::warn!(
426                credential_id = credential_id,
427                chat_id = chat_id,
428                "No WebSocket connection, message dropped"
429            );
430        }
431    } else {
432        // External adapter: POST to adapter's /send endpoint
433        let port = state.adapter_manager.get_port(credential_id).await;
434
435        match port {
436            Some(port) if port > 0 => {
437                let send_req = crate::adapter::AdapterSendRequest {
438                    chat_id: chat_id.to_string(),
439                    text: text.to_string(),
440                    reply_to_message_id: payload
441                        .get("reply_to_message_id")
442                        .and_then(|v| v.as_str())
443                        .map(String::from),
444                    file_path: file_path.clone(),
445                    file_paths: file_paths.clone(),
446                    extra_data: extra_data.clone(),
447                };
448
449                let client = reqwest::Client::new();
450                let url = format!("http://127.0.0.1:{}/send", port);
451
452                match client.post(&url).json(&send_req).send().await {
453                    Ok(resp) if resp.status().is_success() => {
454                        match resp.json::<crate::adapter::AdapterSendResponse>().await {
455                            Ok(adapter_resp) => {
456                                tracing::info!(
457                                    credential_id = credential_id,
458                                    adapter = adapter,
459                                    protocol_message_id = %adapter_resp.protocol_message_id,
460                                    "Message sent via adapter"
461                                );
462                                return Ok(Json(json!({
463                                    "status": "sent",
464                                    "protocol_message_id": adapter_resp.protocol_message_id,
465                                    "timestamp": timestamp.to_rfc3339()
466                                })));
467                            }
468                            Err(e) => {
469                                tracing::error!(
470                                    credential_id = credential_id,
471                                    error = %e,
472                                    "Failed to parse adapter response"
473                                );
474                            }
475                        }
476                    }
477                    Ok(resp) => {
478                        tracing::error!(
479                            credential_id = credential_id,
480                            status = %resp.status(),
481                            "Adapter returned error"
482                        );
483                    }
484                    Err(e) => {
485                        tracing::error!(
486                            credential_id = credential_id,
487                            error = %e,
488                            "Failed to send to adapter"
489                        );
490                    }
491                }
492            }
493            _ => {
494                tracing::warn!(
495                    credential_id = credential_id,
496                    adapter = adapter,
497                    "No adapter instance running for credential"
498                );
499            }
500        }
501    }
502
503    Ok(Json(json!({
504        "status": "sent",
505        "protocol_message_id": message_id,
506        "timestamp": timestamp.to_rfc3339()
507    })))
508}
509
510// Inbound message from external adapter
511async fn adapter_inbound(
512    State(state): State<Arc<AppState>>,
513    Json(payload): Json<crate::adapter::AdapterInboundRequest>,
514) -> Result<impl IntoResponse, AppError> {
515    // Look up credential by instance_id
516    let credential_id = state
517        .adapter_manager
518        .get_credential_id(&payload.instance_id)
519        .await
520        .ok_or_else(|| {
521            tracing::warn!(
522                instance_id = %payload.instance_id,
523                "Could not find credential for instance"
524            );
525            AppError::Internal(format!("Unknown instance: {}", payload.instance_id))
526        })?;
527
528    let config = state.config.read().await;
529
530    let credential = config
531        .credentials
532        .get(&credential_id)
533        .ok_or_else(|| AppError::CredentialNotFound(credential_id.clone()))?;
534
535    if !credential.active {
536        return Err(AppError::CredentialInactive(credential_id.clone()));
537    }
538
539    let route = credential.route.clone();
540    let adapter = credential.adapter.clone();
541
542    let backend_name = crate::backend::resolve_backend_name(credential, &config.gateway)
543        .ok_or_else(|| {
544            AppError::Internal("No backend configured for this credential".to_string())
545        })?;
546    let backend_cfg = config.backends.get(&backend_name).ok_or_else(|| {
547        AppError::Internal(format!("Backend '{}' not found in config", backend_name))
548    })?;
549    let gateway_ctx = crate::backend::GatewayContext {
550        gateway_url: format!("http://{}", config.gateway.listen),
551        send_token: config.auth.send_token.clone(),
552    };
553    let backend_adapter = crate::backend::create_adapter(
554        backend_cfg,
555        Some(&gateway_ctx),
556        credential.config.as_ref().or(backend_cfg.config.as_ref()),
557    )
558    .map_err(|e| AppError::Internal(format!("Failed to create backend adapter: {}", e)))?;
559    drop(config);
560
561    // Build normalized inbound message
562    let timestamp = payload
563        .timestamp
564        .as_ref()
565        .and_then(|t| chrono::DateTime::parse_from_rfc3339(t).ok())
566        .map(|dt| dt.with_timezone(&chrono::Utc))
567        .unwrap_or_else(chrono::Utc::now);
568
569    // Collect all file infos: merge v0.2 `files[]` with v0.1 compat `file`
570    let mut all_files: Vec<&crate::adapter::AdapterFileInfo> = payload.files.iter().collect();
571    if let Some(ref legacy_file) = payload.file
572        && all_files.is_empty()
573    {
574        // Only use legacy `file` if `files[]` is not provided
575        all_files.push(legacy_file);
576    }
577
578    // Handle file attachments
579    let mut attachments = vec![];
580    for file_info in &all_files {
581        if let Some(ref file_cache) = state.file_cache {
582            match file_cache
583                .download_and_cache(
584                    &file_info.url,
585                    file_info.auth_header.as_deref(),
586                    &file_info.filename,
587                    &file_info.mime_type,
588                )
589                .await
590            {
591                Ok(cached) => {
592                    attachments.push(crate::message::Attachment {
593                        filename: cached.filename,
594                        mime_type: cached.mime_type,
595                        size_bytes: cached.size_bytes,
596                        download_url: file_cache.get_download_url(&cached.file_id),
597                    });
598                    tracing::info!(
599                        file_id = %cached.file_id,
600                        filename = %file_info.filename,
601                        "File attachment cached"
602                    );
603                }
604                Err(e) => {
605                    tracing::warn!(
606                        error = %e,
607                        filename = %file_info.filename,
608                        "Failed to cache file attachment"
609                    );
610                    // Include a stub attachment with error info
611                    attachments.push(crate::message::Attachment {
612                        filename: file_info.filename.clone(),
613                        mime_type: file_info.mime_type.clone(),
614                        size_bytes: 0,
615                        download_url: format!("error: {}", e),
616                    });
617                }
618            }
619        } else {
620            tracing::warn!("File attachment received but file cache not configured");
621        }
622    }
623
624    let inbound = crate::message::InboundMessage {
625        route,
626        credential_id: credential_id.clone(),
627        source: crate::message::MessageSource {
628            protocol: adapter.clone(),
629            chat_id: payload.chat_id.clone(),
630            message_id: payload.message_id.clone(),
631            reply_to_message_id: payload.reply_to_message_id,
632            from: crate::message::UserInfo {
633                id: payload.from.id,
634                username: payload.from.username,
635                display_name: payload.from.display_name,
636            },
637        },
638        text: payload.text,
639        attachments,
640        timestamp,
641        extra_data: payload.extra_data,
642    };
643
644    let verdict = {
645        let engine = state.guardrail_engine.read().await;
646        engine.evaluate_inbound(&inbound)
647    };
648    match verdict {
649        GuardrailVerdict::Block { reject_message, .. } => {
650            return Err(AppError::Forbidden(reject_message));
651        }
652        GuardrailVerdict::Allow => {}
653    }
654
655    // Check health state
656    let health_state = state.health_monitor.get_state().await;
657    if health_state == crate::health::HealthState::Down {
658        state.health_monitor.buffer_message(inbound).await;
659        tracing::info!(
660            credential_id = %credential_id,
661            instance_id = %payload.instance_id,
662            "Message buffered (target server down)"
663        );
664    } else {
665        // Forward to backend
666        let instance_id = payload.instance_id.clone();
667        let cred_id = credential_id.clone();
668
669        tokio::spawn(async move {
670            match backend_adapter.send_message(&inbound).await {
671                Ok(()) => {
672                    tracing::debug!(
673                        credential_id = %cred_id,
674                        instance_id = %instance_id,
675                        "Message forwarded to backend"
676                    );
677                }
678                Err(e) => {
679                    tracing::error!(
680                        credential_id = %cred_id,
681                        instance_id = %instance_id,
682                        error = %e,
683                        "Failed to forward message to backend"
684                    );
685                }
686            }
687        });
688    }
689
690    Ok((
691        axum::http::StatusCode::ACCEPTED,
692        Json(json!({
693            "status": "accepted"
694        })),
695    ))
696}
697
698/// Helper to verify send_token from Authorization header
699fn verify_send_token(
700    headers: &axum::http::HeaderMap,
701    expected_token: &str,
702) -> Result<(), AppError> {
703    let auth_header = headers
704        .get(header::AUTHORIZATION)
705        .and_then(|v| v.to_str().ok());
706
707    match auth_header {
708        Some(auth) if auth.starts_with("Bearer ") => {
709            let token = &auth[7..];
710            if token != expected_token {
711                return Err(AppError::Unauthorized);
712            }
713        }
714        _ => return Err(AppError::Unauthorized),
715    }
716    Ok(())
717}
718
719// Upload file endpoint (POST /api/v1/files)
720async fn upload_file(
721    State(state): State<Arc<AppState>>,
722    headers: axum::http::HeaderMap,
723    mut multipart: axum::extract::Multipart,
724) -> Result<impl IntoResponse, AppError> {
725    // Verify send token
726    let config = state.config.read().await;
727    verify_send_token(&headers, &config.auth.send_token)?;
728    drop(config);
729
730    // Get file cache
731    let file_cache = state
732        .file_cache
733        .as_ref()
734        .ok_or_else(|| AppError::Internal("File cache not configured".to_string()))?;
735
736    // Parse multipart fields
737    let mut file_data: Option<Vec<u8>> = None;
738    let mut filename: Option<String> = None;
739    let mut mime_type: Option<String> = None;
740    let mut multipart_filename: Option<String> = None;
741
742    while let Some(field) = multipart
743        .next_field()
744        .await
745        .map_err(|e| AppError::BadRequest(format!("Failed to read multipart field: {}", e)))?
746    {
747        let field_name = field.name().unwrap_or("").to_string();
748
749        match field_name.as_str() {
750            "file" => {
751                // Capture filename from multipart Content-Disposition if available
752                if let Some(fname) = field.file_name() {
753                    multipart_filename = Some(fname.to_string());
754                }
755                let bytes = field.bytes().await.map_err(|e| {
756                    AppError::BadRequest(format!("Failed to read file data: {}", e))
757                })?;
758                file_data = Some(bytes.to_vec());
759            }
760            "filename" => {
761                let text = field
762                    .text()
763                    .await
764                    .map_err(|e| AppError::BadRequest(format!("Failed to read filename: {}", e)))?;
765                filename = Some(text);
766            }
767            "mime_type" => {
768                let text = field.text().await.map_err(|e| {
769                    AppError::BadRequest(format!("Failed to read mime_type: {}", e))
770                })?;
771                mime_type = Some(text);
772            }
773            _ => {
774                // Skip unknown fields
775            }
776        }
777    }
778
779    // Validate required fields
780    let data = file_data.ok_or_else(|| AppError::BadRequest("Missing 'file' field".to_string()))?;
781
782    if data.is_empty() {
783        return Err(AppError::BadRequest("File data is empty".to_string()));
784    }
785
786    // Use explicit filename, fall back to multipart filename
787    let filename = filename
788        .or(multipart_filename)
789        .ok_or_else(|| AppError::BadRequest("Missing 'filename' field".to_string()))?;
790
791    let mime_type = mime_type.unwrap_or_else(|| "application/octet-stream".to_string());
792
793    // Store file — map errors to proper HTTP status codes
794    let cached = file_cache
795        .store_file(data, &filename, &mime_type)
796        .await
797        .map_err(|e| {
798            let msg = e.to_string();
799            if msg.contains("too large") {
800                AppError::PayloadTooLarge(msg)
801            } else if msg.contains("MIME type") {
802                AppError::UnsupportedMediaType(msg)
803            } else {
804                e
805            }
806        })?;
807
808    let download_url = file_cache.get_download_url(&cached.file_id);
809
810    tracing::info!(
811        file_id = %cached.file_id,
812        filename = %cached.filename,
813        size = cached.size_bytes,
814        "File uploaded via API"
815    );
816
817    Ok(Json(json!({
818        "file_id": cached.file_id,
819        "filename": cached.filename,
820        "mime_type": cached.mime_type,
821        "size_bytes": cached.size_bytes,
822        "download_url": download_url
823    })))
824}
825
826// Serve cached files (no auth — file IDs are unguessable UUIDs)
827async fn serve_file(
828    State(state): State<Arc<AppState>>,
829    axum::extract::Path(file_id): axum::extract::Path<String>,
830) -> Result<impl IntoResponse, AppError> {
831    // Get file cache
832    let file_cache = state
833        .file_cache
834        .as_ref()
835        .ok_or_else(|| AppError::Internal("File cache not configured".to_string()))?;
836
837    // Get file metadata
838    let cached = file_cache
839        .get(&file_id)
840        .await
841        .ok_or_else(|| AppError::NotFound(format!("File not found: {}", file_id)))?;
842
843    // Read file content
844    let content = file_cache.read_file(&file_id).await?;
845
846    // Build response with appropriate headers
847    let content_disposition = format!(
848        "attachment; filename=\"{}\"",
849        cached.filename.replace("\"", "\\\"")
850    );
851
852    Ok((
853        [
854            (header::CONTENT_TYPE, cached.mime_type),
855            (header::CONTENT_DISPOSITION, content_disposition),
856        ],
857        content,
858    ))
859}
860
861#[cfg(test)]
862mod tests {
863    use super::*;
864
865    // ==================== SendFileAttachment Tests ====================
866
867    #[test]
868    fn test_send_file_attachment_parse() {
869        let json = r#"{
870            "url": "https://example.com/file.pdf",
871            "filename": "document.pdf",
872            "mime_type": "application/pdf"
873        }"#;
874
875        let attachment: SendFileAttachment = serde_json::from_str(json).unwrap();
876        assert_eq!(attachment.url, "https://example.com/file.pdf");
877        assert_eq!(attachment.filename, "document.pdf");
878        assert_eq!(attachment.mime_type, "application/pdf");
879        assert!(attachment.auth_header.is_none());
880    }
881
882    #[test]
883    fn test_send_file_attachment_with_auth() {
884        let json = r#"{
885            "url": "https://example.com/file.pdf",
886            "filename": "document.pdf",
887            "mime_type": "application/pdf",
888            "auth_header": "Bearer token123"
889        }"#;
890
891        let attachment: SendFileAttachment = serde_json::from_str(json).unwrap();
892        assert_eq!(attachment.auth_header, Some("Bearer token123".to_string()));
893    }
894
895    // ==================== Content-Disposition Escaping Tests ====================
896
897    #[test]
898    fn test_content_disposition_escaping() {
899        // Test filename with quotes
900        let filename = r#"file"name.pdf"#;
901        let content_disposition = format!(
902            "attachment; filename=\"{}\"",
903            filename.replace("\"", "\\\"")
904        );
905        assert_eq!(
906            content_disposition,
907            r#"attachment; filename="file\"name.pdf""#
908        );
909    }
910
911    #[test]
912    fn test_content_disposition_normal() {
913        let filename = "document.pdf";
914        let content_disposition = format!(
915            "attachment; filename=\"{}\"",
916            filename.replace("\"", "\\\"")
917        );
918        assert_eq!(
919            content_disposition,
920            r#"attachment; filename="document.pdf""#
921        );
922    }
923
924    #[test]
925    fn test_send_file_attachment_missing_optional() {
926        // auth_header is optional
927        let json = r#"{
928            "url": "https://example.com/file.txt",
929            "filename": "test.txt",
930            "mime_type": "text/plain"
931        }"#;
932
933        let attachment: SendFileAttachment = serde_json::from_str(json).unwrap();
934        assert!(attachment.auth_header.is_none());
935    }
936
937    #[test]
938    fn test_send_file_attachment_debug() {
939        let attachment = SendFileAttachment {
940            url: "https://example.com/file.pdf".to_string(),
941            filename: "doc.pdf".to_string(),
942            mime_type: "application/pdf".to_string(),
943            auth_header: None,
944        };
945
946        let debug_str = format!("{:?}", attachment);
947        assert!(debug_str.contains("SendFileAttachment"));
948        assert!(debug_str.contains("doc.pdf"));
949    }
950
951    #[test]
952    fn test_content_disposition_special_chars() {
953        // Test with special characters in filename
954        let filename = "file with spaces.pdf";
955        let content_disposition = format!(
956            "attachment; filename=\"{}\"",
957            filename.replace("\"", "\\\"")
958        );
959        assert_eq!(
960            content_disposition,
961            r#"attachment; filename="file with spaces.pdf""#
962        );
963    }
964
965    #[test]
966    fn test_content_disposition_unicode() {
967        // Test with unicode characters
968        let filename = "文档.pdf";
969        let content_disposition = format!(
970            "attachment; filename=\"{}\"",
971            filename.replace("\"", "\\\"")
972        );
973        assert!(content_disposition.contains("文档.pdf"));
974    }
975
976    #[test]
977    fn test_send_file_attachment_various_mime_types() {
978        let test_cases = vec![
979            ("image/png", "image.png"),
980            ("video/mp4", "video.mp4"),
981            ("audio/mpeg", "audio.mp3"),
982            ("application/json", "data.json"),
983            ("text/html", "page.html"),
984        ];
985
986        for (mime_type, filename) in test_cases {
987            let json = format!(
988                r#"{{"url": "https://example.com/{}", "filename": "{}", "mime_type": "{}"}}"#,
989                filename, filename, mime_type
990            );
991            let attachment: SendFileAttachment = serde_json::from_str(&json).unwrap();
992            assert_eq!(attachment.mime_type, mime_type);
993            assert_eq!(attachment.filename, filename);
994        }
995    }
996}