1#[allow(unused_imports)]
2use std::io::Read;
3use std::sync::atomic::{AtomicBool, Ordering};
4use std::sync::Arc;
5use std::time::Instant;
6
7use pylon_auth::SessionStore;
8use pylon_http::HttpMethod;
9use pylon_plugin::PluginRegistry;
10use pylon_policy::PolicyEngine;
11use pylon_sync::{ChangeKind, ChangeLog};
12use tiny_http::{Header, Method, Response, Server};
13
14use crate::datastore::{
15 CacheAdapter, EmailAdapter, LocalFileOps, PluginHooksAdapter, PubSubAdapter,
16 RuntimeOpenApiGenerator, ShardOpsAdapter, WsSseNotifier,
17};
18use crate::jobs::{JobQueue, JobResult, Worker};
19use crate::metrics::Metrics;
20use crate::pubsub::PubSubBroker;
21use crate::rate_limit::RateLimiter;
22use crate::rooms::RoomManager;
23use crate::scheduler::Scheduler;
24use crate::sse::SseHub;
25use crate::workflows::WorkflowEngine;
26use crate::ws::WsHub;
27use crate::Runtime;
28use pylon_plugin::builtin::ai_proxy::{AiMessage, AiProxyPlugin};
29use pylon_plugin::builtin::cache::CachePlugin;
30
/// Blocking [`std::io::Read`] adapter over an `mpsc` channel of byte chunks.
///
/// A producer sends `Vec<u8>` chunks into the channel; each `read` call copies
/// as much of the current chunk as fits into the caller's buffer and keeps the
/// remainder for the next call. End of stream is reported (`Ok(0)`) when the
/// producer sends an empty chunk or when every sender has been dropped.
struct StreamingBody {
    /// Receiving half of the chunk channel.
    rx: std::sync::mpsc::Receiver<Vec<u8>>,
    /// Leftover bytes of the last received chunk that did not fit into the
    /// caller's buffer; empty when nothing is pending.
    buf: Vec<u8>,
    /// Offset into `buf` of the next byte to hand out.
    pos: usize,
}

impl StreamingBody {
    /// Wrap the receiving side of a chunk channel.
    fn new(rx: std::sync::mpsc::Receiver<Vec<u8>>) -> Self {
        Self {
            rx,
            buf: Vec::new(),
            pos: 0,
        }
    }
}

impl std::io::Read for StreamingBody {
    /// Copy buffered or freshly received bytes into `out`.
    ///
    /// Blocks on the channel only when no leftover bytes are pending and `out`
    /// is non-empty. Never returns an error: a disconnected channel is treated
    /// as end of stream.
    fn read(&mut self, out: &mut [u8]) -> std::io::Result<usize> {
        // A zero-length read can copy nothing; answer immediately instead of
        // blocking on (and pulling a chunk from) the channel.
        if out.is_empty() {
            return Ok(0);
        }

        // Drain leftovers from the previous chunk first.
        if self.pos < self.buf.len() {
            let remaining = &self.buf[self.pos..];
            let n = remaining.len().min(out.len());
            out[..n].copy_from_slice(&remaining[..n]);
            self.pos += n;
            if self.pos >= self.buf.len() {
                self.buf.clear();
                self.pos = 0;
            }
            return Ok(n);
        }

        match self.rx.recv() {
            // An empty chunk is the producer's explicit end-of-stream marker.
            Ok(chunk) if chunk.is_empty() => Ok(0),
            Ok(chunk) => {
                let n = chunk.len().min(out.len());
                out[..n].copy_from_slice(&chunk[..n]);
                if n < chunk.len() {
                    // Keep the tail for subsequent reads.
                    self.buf = chunk;
                    self.pos = n;
                }
                Ok(n)
            }
            // All senders dropped and the channel is drained: end of stream.
            Err(_) => Ok(0),
        }
    }
}
88
89static SHUTDOWN: AtomicBool = AtomicBool::new(false);
91
92pub fn request_shutdown() {
98 SHUTDOWN.store(true, Ordering::SeqCst);
99 if let Some(srv) = SERVER_HANDLE.get() {
102 srv.unblock();
103 }
104}
105
106static SERVER_HANDLE: std::sync::OnceLock<Arc<Server>> = std::sync::OnceLock::new();
109
110fn resolve_client_ip(request: &tiny_http::Request, trust_proxy_hops: usize) -> String {
126 let socket_ip = request
127 .remote_addr()
128 .map(|a| a.ip().to_string())
129 .unwrap_or_default();
130 if trust_proxy_hops == 0 {
131 return socket_ip;
132 }
133 let xff = request
136 .headers()
137 .iter()
138 .find(|h| {
139 h.field
140 .as_str()
141 .as_str()
142 .eq_ignore_ascii_case("X-Forwarded-For")
143 })
144 .map(|h| h.value.as_str().to_string());
145 let Some(xff) = xff else {
146 return socket_ip;
147 };
148 let entries: Vec<&str> = xff.split(',').map(str::trim).collect();
153 if entries.len() < trust_proxy_hops {
154 return socket_ip;
159 }
160 let candidate = entries[entries.len() - trust_proxy_hops];
161 if candidate.parse::<std::net::IpAddr>().is_ok() {
164 candidate.to_string()
165 } else {
166 socket_ip
167 }
168}
169
170fn security_headers() -> Vec<Header> {
179 vec![
180 Header::from_bytes("X-Content-Type-Options", "nosniff").unwrap(),
181 Header::from_bytes("X-Frame-Options", "DENY").unwrap(),
182 Header::from_bytes("X-XSS-Protection", "1; mode=block").unwrap(),
183 Header::from_bytes("Referrer-Policy", "strict-origin-when-cross-origin").unwrap(),
187 Header::from_bytes(
191 "Permissions-Policy",
192 "accelerometer=(), camera=(), geolocation=(), gyroscope=(), microphone=(), payment=(), usb=()",
193 )
194 .unwrap(),
195 ]
196}
197
198fn with_security_headers<R: std::io::Read>(response: Response<R>) -> Response<R> {
200 let mut resp = response;
201 for header in security_headers() {
202 resp = resp.with_header(header);
203 }
204 resp
205}
206
207pub fn start(runtime: Arc<Runtime>, port: u16) -> Result<(), String> {
209 start_with_plugins(runtime, port, None)
210}
211
212pub fn start_with_plugins(
214 runtime: Arc<Runtime>,
215 port: u16,
216 plugins: Option<Arc<PluginRegistry>>,
217) -> Result<(), String> {
218 start_server(runtime, port, plugins, None)
219}
220
221pub fn start_with_shards(
224 runtime: Arc<Runtime>,
225 port: u16,
226 plugins: Option<Arc<PluginRegistry>>,
227 shard_registry: Arc<dyn pylon_realtime::DynShardRegistry>,
228) -> Result<(), String> {
229 start_server(runtime, port, plugins, Some(shard_registry))
230}
231
232fn start_server(
233 runtime: Arc<Runtime>,
234 port: u16,
235 plugins: Option<Arc<PluginRegistry>>,
236 shard_registry: Option<Arc<dyn pylon_realtime::DynShardRegistry>>,
237) -> Result<(), String> {
238 pylon_observability::run_tracing_hook();
243
244 let addr = format!("0.0.0.0:{port}");
245 let server = Server::http(&addr).map_err(|e| format!("Failed to start server: {e}"))?;
246 let server = Arc::new(server);
247
248 let _ = SERVER_HANDLE.set(Arc::clone(&server));
250
251 let session_lifetime = runtime.manifest().auth.session.expires_in;
252 let auth_stores = build_auth_stores(runtime.db_path().as_deref(), session_lifetime);
253 let session_store = auth_stores.session_store;
254 let magic_codes = auth_stores.magic_codes;
255 let oauth_state = auth_stores.oauth_state;
256 let account_store = auth_stores.account_store;
257 let api_keys = auth_stores.api_keys;
258 let orgs = auth_stores.orgs;
259 let siwe = auth_stores.siwe;
260 let phone_codes = auth_stores.phone_codes;
261 let passkeys = auth_stores.passkeys;
262 let policy_engine = Arc::new(PolicyEngine::from_manifest(runtime.manifest()));
263 let change_log = Arc::new(ChangeLog::new());
264
265 for entity in runtime.manifest().entities.iter() {
273 match runtime.list(&entity.name) {
274 Ok(rows) => {
275 for row in rows {
276 if let Some(id) = row.get("id").and_then(|v| v.as_str()) {
277 change_log.append(&entity.name, id, ChangeKind::Insert, Some(row.clone()));
278 }
279 }
280 }
281 Err(_) => {
282 }
284 }
285 }
286 let ws_hub = WsHub::new();
287 let sse_hub = SseHub::new();
288 let is_dev_early = std::env::var("PYLON_DEV_MODE")
301 .map(|v| v == "1" || v == "true")
302 .unwrap_or(true);
303 let plugin_rl_max: u32 = if is_dev_early { 100_000 } else { 100 };
304 let plugin_reg: Arc<PluginRegistry> = plugins.unwrap_or_else(|| {
305 let mut reg = PluginRegistry::new(runtime.manifest().clone());
306 reg.register(Arc::new(
307 pylon_plugin::builtin::rate_limit::RateLimitPlugin::new(
308 plugin_rl_max,
309 std::time::Duration::from_secs(60),
310 ),
311 ));
312 reg.register(Arc::new(
317 pylon_plugin::builtin::tenant_scope::TenantScopePlugin::from_manifest(
318 runtime.manifest(),
319 ),
320 ));
321 Arc::new(reg)
322 });
323 let room_mgr = Arc::new(RoomManager::new(120)); let ws_port = port + 1;
325 let sse_port = port + 2;
326
327 let start_time = Instant::now();
329
330 let metrics = Arc::new(Metrics::new());
331
332 let cache = Arc::new(CachePlugin::new(100_000));
334 let pubsub_broker = Arc::new(PubSubBroker::new(100));
335
336 let job_queue = Arc::new(JobQueue::new(1000));
338
339 let jobs_in_memory = std::env::var("PYLON_JOBS_IN_MEMORY")
345 .map(|v| v == "1" || v == "true")
346 .unwrap_or(false);
347 if !jobs_in_memory {
348 let jobs_db_path = std::env::var("PYLON_JOBS_DB").ok().unwrap_or_else(|| {
349 runtime
350 .db_path()
351 .map(|p| format!("{p}.jobs.db"))
352 .unwrap_or_else(|| "pylon.jobs.db".into())
353 });
354 match crate::job_store::JobStore::open(&jobs_db_path) {
355 Ok(store) => {
356 let store = Arc::new(store);
357 let restored = job_queue.restore_from(&store);
358 if restored > 0 {
359 tracing::info!("[jobs] Restored {restored} pending job(s) from {jobs_db_path}");
360 }
361 job_queue.attach_store(store);
362 }
363 Err(e) => {
364 tracing::warn!(
365 "[jobs] Could not open job store at {jobs_db_path}: {e} — running without persistence"
366 );
367 }
368 }
369 }
370
371 {
373 let cache_ref = Arc::clone(&cache);
374 job_queue.register(
375 "pylon.cache.cleanup",
376 Arc::new(move |_job| {
377 cache_ref.cleanup_expired();
378 JobResult::Success
379 }),
380 );
381 let rooms_ref = Arc::clone(&room_mgr);
382 job_queue.register(
383 "pylon.rooms.cleanup",
384 Arc::new(move |_job| {
385 rooms_ref.cleanup_idle();
386 JobResult::Success
387 }),
388 );
389 }
390
391 let scheduler = Arc::new(Scheduler::new(Arc::clone(&job_queue)));
392 let _ = scheduler.schedule(
394 "pylon.cache.cleanup",
395 "*/10 * * * *",
396 Arc::new(|_| JobResult::Success),
397 );
398 let _ = scheduler.schedule(
399 "pylon.rooms.cleanup",
400 "*/5 * * * *",
401 Arc::new(|_| JobResult::Success),
402 );
403
404 let _worker_handles: Vec<_> = (0..2)
406 .map(|i| {
407 let w = Worker::new(Arc::clone(&job_queue), &format!("worker-{i}"));
408 w.start()
409 })
410 .collect();
411
412 let _scheduler_handle = Arc::clone(&scheduler).start();
414
415 let wf_runner_url = std::env::var("PYLON_WORKFLOW_RUNNER_URL")
417 .unwrap_or_else(|_| "http://127.0.0.1:9876/run".to_string());
418 let workflow_engine = Arc::new(WorkflowEngine::new(&wf_runner_url, 10_000));
419
420 let default_rl_max = if is_dev_early { 100_000 } else { 600 };
431 let rl_max: u32 = std::env::var("PYLON_RATE_LIMIT_MAX")
432 .ok()
433 .and_then(|v| v.parse().ok())
434 .unwrap_or(default_rl_max);
435 let rl_window: u64 = std::env::var("PYLON_RATE_LIMIT_WINDOW")
436 .ok()
437 .and_then(|v| v.parse().ok())
438 .unwrap_or(60);
439 let rate_limiter = Arc::new(RateLimiter::new(rl_max, rl_window));
440
441 let fn_rl_max: u32 = std::env::var("PYLON_FN_RATE_LIMIT_MAX")
445 .ok()
446 .and_then(|v| v.parse().ok())
447 .unwrap_or(30);
448 let fn_rl_window: u64 = std::env::var("PYLON_FN_RATE_LIMIT_WINDOW")
449 .ok()
450 .and_then(|v| v.parse().ok())
451 .unwrap_or(60);
452 let fn_rate_limiter = Arc::new(RateLimiter::new(fn_rl_max, fn_rl_window));
453
454 let fn_notifier: Arc<dyn pylon_router::ChangeNotifier> =
462 Arc::new(crate::datastore::WsSseNotifier {
463 ws: Arc::clone(&ws_hub),
464 sse: Arc::clone(&sse_hub),
465 });
466 let fn_ops_maybe = crate::datastore::try_spawn_functions(
467 Arc::clone(&runtime),
468 Arc::clone(&job_queue),
469 Arc::clone(&fn_rate_limiter),
470 Arc::clone(&change_log),
471 fn_notifier,
472 );
473
474 let is_dev = std::env::var("PYLON_DEV_MODE")
482 .map(|v| v == "1" || v == "true")
483 .unwrap_or(false);
484
485 let cors_origin = match std::env::var("PYLON_CORS_ORIGIN") {
492 Ok(v) => v,
493 Err(_) if is_dev => "*".to_string(),
494 Err(_) => {
495 return Err(
496 "PYLON_CORS_ORIGIN must be set in production (non-dev mode). \
497 Set it to your frontend's origin, or set PYLON_DEV_MODE=true \
498 for local development."
499 .into(),
500 );
501 }
502 };
503 if !is_dev && cors_origin == "*" {
504 return Err("PYLON_CORS_ORIGIN=\"*\" is refused in production mode. \
505 Set it to an explicit origin (https://app.example.com)."
506 .into());
507 }
508 let allow_credentials = cors_origin != "*";
515 if Header::from_bytes(
519 "Access-Control-Allow-Origin",
520 cors_origin.as_bytes().to_vec(),
521 )
522 .is_err()
523 {
524 return Err(format!(
525 "PYLON_CORS_ORIGIN={cors_origin:?} contains bytes that are not a valid HTTP header value"
526 ));
527 }
528
529 let admin_token: Option<String> = std::env::var("PYLON_ADMIN_TOKEN").ok();
531
532 let trust_proxy_hops: usize = std::env::var("PYLON_TRUST_PROXY_HOPS")
542 .ok()
543 .and_then(|v| v.parse().ok())
544 .unwrap_or(0);
545
546 let cookie_config = Arc::new({
553 let app_name = runtime.manifest().name.as_str();
554 pylon_auth::CookieConfig::from_env(&pylon_auth::CookieConfig::default_name_for(app_name))
555 });
556
557 let csrf_origins: Vec<String> = match std::env::var("PYLON_CSRF_ORIGINS") {
567 Ok(v) => v
568 .split(',')
569 .map(|s| s.trim().to_string())
570 .filter(|s| !s.is_empty())
571 .collect(),
572 Err(_) => {
573 if is_dev {
574 vec!["*".to_string()]
575 } else if cors_origin != "*" {
576 vec![cors_origin.clone()]
577 } else {
578 vec![]
580 }
581 }
582 };
583 let csrf = Arc::new(pylon_plugin::builtin::csrf::CsrfPlugin::new(csrf_origins));
584
585 let manifest_trusted: Vec<String> = runtime.manifest().auth.trusted_origins.clone();
599 let trusted_origins: Vec<String> = std::env::var("PYLON_TRUSTED_ORIGINS")
600 .map(|v| {
601 v.split(',')
602 .map(|s| s.trim().to_string())
603 .filter(|s| !s.is_empty())
604 .collect()
605 })
606 .unwrap_or_else(|_| {
607 if is_dev_early {
614 vec![
615 "http://localhost:3000".to_string(),
616 "http://localhost:4321".to_string(),
617 "http://localhost:5173".to_string(),
618 "http://127.0.0.1:3000".to_string(),
619 ]
620 } else {
621 Vec::new()
622 }
623 });
624 let mut combined: Vec<String> = trusted_origins;
626 for m in manifest_trusted {
627 if !m.is_empty() && !combined.contains(&m) {
628 combined.push(m);
629 }
630 }
631 let trusted_origins = Arc::new(combined);
632
633 {
647 let hub = Arc::clone(&ws_hub);
648 let sessions = Arc::clone(&session_store);
649 let runtime_for_fetcher = Arc::clone(&runtime);
650 let pe_for_fetcher = Arc::clone(&policy_engine);
651 let fetcher: crate::ws::SnapshotFetcher = Arc::new(move |auth_ctx, entity, row_id| {
652 use pylon_http::DataStore;
653 let row = match runtime_for_fetcher.get_by_id(entity, row_id) {
658 Ok(Some(v)) => v,
659 _ => return None,
660 };
661 if !matches!(
662 pe_for_fetcher.check_entity_read(entity, auth_ctx, Some(&row)),
663 pylon_policy::PolicyResult::Allowed
664 ) {
665 return None;
666 }
667 let snap = match runtime_for_fetcher.crdt_snapshot(entity, row_id) {
668 Ok(Some(bytes)) => bytes,
669 _ => return None,
670 };
671 pylon_router::encode_crdt_frame(
672 pylon_router::CRDT_FRAME_SNAPSHOT,
673 entity,
674 row_id,
675 &snap,
676 )
677 .ok()
678 });
679 std::thread::spawn(move || {
680 crate::ws::start_ws_server(hub, sessions, ws_port, Some(fetcher));
681 });
682 }
683
684 {
686 let hub = Arc::clone(&sse_hub);
687 std::thread::spawn(move || {
688 crate::sse::start_sse_server(hub, sse_port);
689 });
690 }
691
692 let shard_ws_port = port + 3;
694 if let Some(reg) = shard_registry.clone() {
695 let sessions = Arc::clone(&session_store);
696 std::thread::spawn(move || {
697 crate::shard_ws::start_shard_ws_server(reg, sessions, shard_ws_port);
698 });
699 }
700
701 tracing::warn!("pylon dev server listening on http://localhost:{port}");
702 tracing::info!(" WebSocket: ws://localhost:{ws_port}");
703 tracing::info!(" Studio: http://localhost:{port}/studio");
704 tracing::info!(" API: http://localhost:{port}/api/entities/<entity>");
705 tracing::info!(" Auth: http://localhost:{port}/api/auth/session");
706
707 loop {
711 if SHUTDOWN.load(Ordering::Relaxed) {
712 break;
713 }
714
715 let mut request = match server.recv() {
716 Ok(rq) => rq,
717 Err(_) => {
718 break;
720 }
721 };
722
723 if SHUTDOWN.load(Ordering::Relaxed) {
724 break;
725 }
726
727 let rt = Arc::clone(&runtime);
728 let ss = Arc::clone(&session_store);
729 let pe = Arc::clone(&policy_engine);
730 let cl = Arc::clone(&change_log);
731 let wh = Arc::clone(&ws_hub);
732 let sh = Arc::clone(&sse_hub);
733 let mc = Arc::clone(&magic_codes);
734 let pr = Arc::clone(&plugin_reg);
735 let rm = Arc::clone(&room_mgr);
736 let mt = Arc::clone(&metrics);
737 let os = Arc::clone(&oauth_state);
738 let acc = Arc::clone(&account_store);
739 let ak = Arc::clone(&api_keys);
740 let og = Arc::clone(&orgs);
741 let sw = Arc::clone(&siwe);
742 let pcd = Arc::clone(&phone_codes);
743 let pks = Arc::clone(&passkeys);
744 let trusted_origins_ref = Arc::clone(&trusted_origins);
745 let ca = Arc::clone(&cache);
746 let ps = Arc::clone(&pubsub_broker);
747 let jq = Arc::clone(&job_queue);
748 let sc = Arc::clone(&scheduler);
749 let we = Arc::clone(&workflow_engine);
750 let fn_ops_ref = fn_ops_maybe.clone();
751 let shards_ref = shard_registry.clone();
752 let cors_origin = cors_origin.clone();
753 let cookie_config = Arc::clone(&cookie_config);
754 let allow_credentials = allow_credentials;
755 let is_dev = is_dev;
756
757 let method = request.method().clone();
758 let url = request.url().to_string();
759
760 let request_peer_ip = resolve_client_ip(&request, trust_proxy_hops);
772 let request_started_at = std::time::Instant::now();
773 if url != "/health" && url != "/metrics" {
774 tracing::info!("→ {} {} from {}", method.as_str(), url, request_peer_ip);
775 crate::metrics::set_current_request(&url, request_started_at);
779 }
780
781 if url == "/health" && method == Method::Get {
783 let uptime = start_time.elapsed().as_secs();
784 let body = serde_json::json!({
785 "status": "ok",
786 "version": "0.1.0",
787 "uptime_secs": uptime,
788 })
789 .to_string();
790
791 let response = with_security_headers(
792 Response::from_string(&body)
793 .with_status_code(200u16)
794 .with_header(Header::from_bytes("Content-Type", "application/json").unwrap())
795 .with_header(
796 Header::from_bytes(
797 "Access-Control-Allow-Origin",
798 cors_origin.as_bytes().to_vec(),
799 )
800 .unwrap(),
801 ),
802 );
803 let _ = request.respond(response);
804 continue;
805 }
806
807 if url == "/metrics" && method == Method::Get {
812 if !is_dev {
813 let admin_bytes = admin_token.as_deref().unwrap_or("").as_bytes();
814 let auth_ok = !admin_bytes.is_empty()
815 && request.headers().iter().any(|h| {
816 let name = h.field.as_str().as_str();
817 name.eq_ignore_ascii_case("Authorization")
818 && h.value
819 .as_str()
820 .strip_prefix("Bearer ")
821 .map(|t| pylon_auth::constant_time_eq(t.as_bytes(), admin_bytes))
822 .unwrap_or(false)
823 });
824 if !auth_ok {
825 let body = json_error(
826 "UNAUTHORIZED",
827 "/metrics requires admin bearer token in non-dev mode",
828 );
829 let response = with_security_headers(
830 Response::from_string(&body)
831 .with_status_code(401u16)
832 .with_header(
833 Header::from_bytes("Content-Type", "application/json").unwrap(),
834 ),
835 );
836 let _ = request.respond(response);
837 continue;
838 }
839 }
840 let prefers_prometheus = request.headers().iter().any(|h| {
841 (h.field.as_str() == "Accept" || h.field.as_str() == "accept")
842 && (h.value.as_str().contains("text/plain")
843 || h.value.as_str().contains("application/openmetrics-text"))
844 });
845 let (body, content_type) = if prefers_prometheus {
846 (mt.prometheus(), "text/plain; version=0.0.4")
847 } else {
848 (mt.snapshot().to_string(), "application/json")
849 };
850 let response = with_security_headers(
851 Response::from_string(&body)
852 .with_status_code(200u16)
853 .with_header(Header::from_bytes("Content-Type", content_type).unwrap())
854 .with_header(
855 Header::from_bytes(
856 "Access-Control-Allow-Origin",
857 cors_origin.as_bytes().to_vec(),
858 )
859 .unwrap(),
860 ),
861 );
862 let _ = request.respond(response);
863 mt.record_request("GET", 200);
864 continue;
865 }
866
867 let peer_ip = resolve_client_ip(&request, trust_proxy_hops);
873
874 let is_preflight = matches!(method, Method::Options);
880 if !is_preflight {
881 if let Err(retry_after) = rate_limiter.check(&peer_ip) {
882 let err_body = json_error(
883 "RATE_LIMITED",
884 &format!("Too many requests. Retry after {retry_after} seconds."),
885 );
886 let response = with_security_headers(
887 Response::from_string(&err_body)
888 .with_status_code(429u16)
889 .with_header(
890 Header::from_bytes("Content-Type", "application/json").unwrap(),
891 )
892 .with_header(
893 Header::from_bytes(
894 "Access-Control-Allow-Origin",
895 cors_origin.as_bytes().to_vec(),
896 )
897 .unwrap(),
898 )
899 .with_header(
900 Header::from_bytes(
901 "Access-Control-Allow-Methods",
902 "GET, POST, PATCH, DELETE, OPTIONS",
903 )
904 .unwrap(),
905 )
906 .with_header(
907 Header::from_bytes(
908 "Access-Control-Allow-Headers",
909 "Content-Type, Authorization",
910 )
911 .unwrap(),
912 )
913 .with_header(
914 Header::from_bytes(
915 "Retry-After",
916 retry_after.to_string().as_bytes().to_vec(),
917 )
918 .unwrap(),
919 ),
920 );
921 let _ = request.respond(response);
922 mt.record_request(method.as_str(), 429);
923 continue;
924 }
925 } {
941 let method_str = method.as_str();
942 let is_bearer = request.headers().iter().any(|h| {
943 (h.field.as_str() == "Authorization" || h.field.as_str() == "authorization")
944 && h.value.as_str().starts_with("Bearer ")
945 });
946 if !is_bearer && !matches!(method, Method::Get | Method::Head | Method::Options) {
951 let origin = request
952 .headers()
953 .iter()
954 .find(|h| h.field.as_str() == "Origin" || h.field.as_str() == "origin")
955 .map(|h| h.value.as_str().to_string());
956 let referer = request
957 .headers()
958 .iter()
959 .find(|h| h.field.as_str() == "Referer" || h.field.as_str() == "referer")
960 .map(|h| h.value.as_str().to_string());
961 if let Err(err) = csrf.check(method_str, origin.as_deref(), referer.as_deref()) {
962 let body = json_error(&err.code, &err.message);
963 let response = with_security_headers(
964 Response::from_string(&body)
965 .with_status_code(err.status)
966 .with_header(
967 Header::from_bytes("Content-Type", "application/json").unwrap(),
968 )
969 .with_header(
970 Header::from_bytes(
971 "Access-Control-Allow-Origin",
972 cors_origin.as_bytes().to_vec(),
973 )
974 .unwrap(),
975 ),
976 );
977 let _ = request.respond(response);
978 mt.record_request(method_str, err.status);
979 continue;
980 }
981 }
982 }
983
984 let bearer_token: Option<String> = request
994 .headers()
995 .iter()
996 .find(|h| h.field.as_str() == "Authorization" || h.field.as_str() == "authorization")
997 .and_then(|h| {
998 let val = h.value.as_str();
999 val.strip_prefix("Bearer ").map(|t| t.to_string())
1000 });
1001 let cookie_token: Option<String> = if bearer_token.is_some() {
1002 None
1003 } else {
1004 request
1005 .headers()
1006 .iter()
1007 .find(|h| h.field.as_str() == "Cookie" || h.field.as_str() == "cookie")
1008 .and_then(|h| {
1009 pylon_auth::extract_session_cookie(h.value.as_str(), &cookie_config.name)
1010 })
1011 };
1012 let auth_token: Option<String> = bearer_token.or(cookie_token);
1013 let auth_ctx_result: Result<pylon_auth::AuthContext, &'static str> = if admin_token.is_some()
1022 && auth_token.is_some()
1023 && pylon_auth::constant_time_eq(
1024 auth_token.as_deref().unwrap_or("").as_bytes(),
1025 admin_token.as_deref().unwrap_or("").as_bytes(),
1026 ) {
1027 Ok(pylon_auth::AuthContext::admin())
1028 } else if let Some(t) = auth_token.as_deref() {
1029 if t.starts_with("pk.") {
1030 match ak.verify(t) {
1031 Ok(key) => Ok(pylon_auth::AuthContext::from_api_key(
1032 key.user_id,
1033 key.id,
1034 key.scopes,
1035 )),
1036 Err(_) => Err("INVALID_API_KEY"),
1037 }
1038 } else if pylon_auth::jwt::looks_like_jwt(t) && jwt_secret().is_some() {
1039 let Some(issuer) = jwt_issuer() else {
1045 tracing::warn!(
1046 "[auth] PYLON_JWT_SECRET set but PYLON_JWT_ISSUER missing — \
1047 refusing JWT verify (set both to enable JWT sessions)"
1048 );
1049 Err("JWT_MISCONFIGURED")?;
1050 unreachable!();
1051 };
1052 let secret = jwt_secret().expect("checked above");
1053 match pylon_auth::jwt::verify(t, secret.as_bytes(), Some(issuer)) {
1054 Ok(claims) => {
1055 let mut ctx = pylon_auth::AuthContext::authenticated(claims.sub);
1056 ctx.roles = claims.roles;
1057 if let Some(t) = claims.tenant_id {
1058 ctx = ctx.with_tenant(t);
1059 }
1060 Ok(ctx)
1061 }
1062 Err(_) => Err("INVALID_JWT"),
1063 }
1064 } else {
1065 Ok(ss.resolve(Some(t)))
1066 }
1067 } else {
1068 Ok(ss.resolve(None))
1069 };
1070 let auth_ctx = match auth_ctx_result {
1071 Ok(c) => c,
1072 Err(reason) => {
1073 let body = format!(
1074 r#"{{"error":{{"code":"{reason}","message":"Bearer token is malformed, expired, or revoked"}}}}"#
1075 );
1076 let resp = tiny_http::Response::from_string(body)
1077 .with_status_code(401)
1078 .with_header(
1079 "Content-Type: application/json"
1080 .parse::<tiny_http::Header>()
1081 .unwrap(),
1082 );
1083 let _ = request.respond(resp);
1084 continue;
1085 }
1086 };
1087
1088 if url == "/api/__test__/reset" && method == Method::Post {
1103 let is_loopback = peer_ip == "127.0.0.1"
1104 || peer_ip == "::1"
1105 || peer_ip.starts_with("127.")
1106 || peer_ip == "localhost";
1107 if !is_dev || !rt.is_in_memory() || !is_loopback {
1108 let body = json_error(
1109 "RESET_REFUSED",
1110 "reset endpoint is only available in dev mode + in-memory DB + from loopback",
1111 );
1112 let response = with_security_headers(
1113 Response::from_string(&body)
1114 .with_status_code(403u16)
1115 .with_header(
1116 Header::from_bytes("Content-Type", "application/json").unwrap(),
1117 )
1118 .with_header(
1119 Header::from_bytes(
1120 "Access-Control-Allow-Origin",
1121 cors_origin.as_bytes().to_vec(),
1122 )
1123 .unwrap(),
1124 ),
1125 );
1126 let _ = request.respond(response);
1127 mt.record_request("POST", 403);
1128 continue;
1129 }
1130 let (status, body) = match rt.reset_for_tests() {
1131 Ok(()) => (200u16, "{\"reset\":true}".to_string()),
1132 Err(e) => (500u16, json_error(&e.code, &e.message)),
1133 };
1134 let response = with_security_headers(
1135 Response::from_string(&body)
1136 .with_status_code(status)
1137 .with_header(Header::from_bytes("Content-Type", "application/json").unwrap())
1138 .with_header(
1139 Header::from_bytes(
1140 "Access-Control-Allow-Origin",
1141 cors_origin.as_bytes().to_vec(),
1142 )
1143 .unwrap(),
1144 ),
1145 );
1146 let _ = request.respond(response);
1147 mt.record_request("POST", status);
1148 continue;
1149 }
1150
1151 if url == "/api/files/upload" && method == Method::Post {
1160 const UPLOAD_MAX: usize = 10 * 1024 * 1024;
1161 if let Some(declared) = request.body_length() {
1164 if declared > UPLOAD_MAX {
1165 let err = json_error(
1166 "PAYLOAD_TOO_LARGE",
1167 &format!("Content-Length {declared} exceeds upload max of {UPLOAD_MAX}"),
1168 );
1169 let response = with_security_headers(
1170 Response::from_string(&err)
1171 .with_status_code(413u16)
1172 .with_header(
1173 Header::from_bytes("Content-Type", "application/json").unwrap(),
1174 )
1175 .with_header(
1176 Header::from_bytes(
1177 "Access-Control-Allow-Origin",
1178 cors_origin.as_bytes().to_vec(),
1179 )
1180 .unwrap(),
1181 ),
1182 );
1183 let _ = request.respond(response);
1184 mt.record_request("POST", 413);
1185 continue;
1186 }
1187 }
1188 if auth_ctx.user_id.is_none() {
1189 let err = json_error(
1190 "AUTH_REQUIRED",
1191 "/api/files/upload requires an authenticated session",
1192 );
1193 let response = with_security_headers(
1194 Response::from_string(&err)
1195 .with_status_code(401u16)
1196 .with_header(
1197 Header::from_bytes("Content-Type", "application/json").unwrap(),
1198 )
1199 .with_header(
1200 Header::from_bytes(
1201 "Access-Control-Allow-Origin",
1202 cors_origin.as_bytes().to_vec(),
1203 )
1204 .unwrap(),
1205 ),
1206 );
1207 let _ = request.respond(response);
1208 mt.record_request("POST", 401);
1209 continue;
1210 }
1211 use std::io::Read;
1216 let mut bytes: Vec<u8> = Vec::with_capacity(8192);
1217 let mut limited = request.as_reader().take((UPLOAD_MAX as u64) + 1);
1218 let _ = limited.read_to_end(&mut bytes);
1219
1220 const MAX: usize = UPLOAD_MAX;
1221 if bytes.len() > MAX {
1222 let err = json_error("PAYLOAD_TOO_LARGE", "File exceeds 10 MB limit");
1223 let response = with_security_headers(
1224 Response::from_string(&err)
1225 .with_status_code(413u16)
1226 .with_header(
1227 Header::from_bytes("Content-Type", "application/json").unwrap(),
1228 )
1229 .with_header(
1230 Header::from_bytes(
1231 "Access-Control-Allow-Origin",
1232 cors_origin.as_bytes().to_vec(),
1233 )
1234 .unwrap(),
1235 ),
1236 );
1237 let _ = request.respond(response);
1238 mt.record_request("POST", 413);
1239 continue;
1240 }
1241
1242 let content_type = request
1244 .headers()
1245 .iter()
1246 .find(|h| h.field.as_str() == "Content-Type" || h.field.as_str() == "content-type")
1247 .map(|h| h.value.as_str().to_string())
1248 .unwrap_or_else(|| "application/octet-stream".into());
1249 let filename = request
1250 .headers()
1251 .iter()
1252 .find(|h| h.field.as_str() == "X-Filename" || h.field.as_str() == "x-filename")
1253 .map(|h| h.value.as_str().to_string())
1254 .unwrap_or_else(|| "upload".into());
1255
1256 let (name, ct, payload) = if content_type.starts_with("multipart/form-data") {
1258 match parse_multipart_first_file(&bytes, &content_type) {
1259 Some(p) => p,
1260 None => {
1261 let err = json_error("INVALID_MULTIPART", "Could not parse multipart body");
1262 let response = with_security_headers(
1263 Response::from_string(&err)
1264 .with_status_code(400u16)
1265 .with_header(
1266 Header::from_bytes("Content-Type", "application/json").unwrap(),
1267 )
1268 .with_header(
1269 Header::from_bytes(
1270 "Access-Control-Allow-Origin",
1271 cors_origin.as_bytes().to_vec(),
1272 )
1273 .unwrap(),
1274 ),
1275 );
1276 let _ = request.respond(response);
1277 mt.record_request("POST", 400);
1278 continue;
1279 }
1280 }
1281 } else {
1282 (filename, content_type, bytes)
1283 };
1284
1285 let storage = pylon_storage::files::LocalFileStorage::new(
1286 &std::env::var("PYLON_FILES_DIR").unwrap_or_else(|_| "uploads".into()),
1287 &std::env::var("PYLON_FILES_URL_PREFIX").unwrap_or_else(|_| "/api/files".into()),
1288 );
1289
1290 let (status, body) =
1291 match pylon_storage::files::FileStorage::store(&storage, &name, &payload, &ct) {
1292 Ok(stored) => (
1293 201u16,
1294 serde_json::to_string(&stored).unwrap_or_else(|_| "{}".into()),
1295 ),
1296 Err(e) => (500u16, json_error(&e.code, &e.message)),
1297 };
1298
1299 let response = with_security_headers(
1300 Response::from_string(&body)
1301 .with_status_code(status)
1302 .with_header(Header::from_bytes("Content-Type", "application/json").unwrap())
1303 .with_header(
1304 Header::from_bytes(
1305 "Access-Control-Allow-Origin",
1306 cors_origin.as_bytes().to_vec(),
1307 )
1308 .unwrap(),
1309 ),
1310 );
1311 let _ = request.respond(response);
1312 mt.record_request("POST", status);
1313 continue;
1314 }
1315
1316 const MAX_BODY_SIZE: usize = 10 * 1024 * 1024;
1326
1327 if let Some(declared) = request.body_length() {
1328 if declared > MAX_BODY_SIZE {
1329 let err_body = json_error(
1330 "PAYLOAD_TOO_LARGE",
1331 &format!("Content-Length {declared} exceeds max of {MAX_BODY_SIZE}"),
1332 );
1333 let response = with_security_headers(
1334 Response::from_string(&err_body)
1335 .with_status_code(413u16)
1336 .with_header(
1337 Header::from_bytes(
1338 "Access-Control-Allow-Origin",
1339 cors_origin.as_bytes().to_vec(),
1340 )
1341 .unwrap(),
1342 ),
1343 );
1344 let _ = request.respond(response);
1345 mt.record_request(method.as_str(), 413);
1346 continue;
1347 }
1348 }
1349
1350 let mut body = String::new();
1351 if !matches!(
1352 method,
1353 Method::Get | Method::Head | Method::Options | Method::Delete
1354 ) {
1355 use std::io::Read;
1356 let mut limited = request.as_reader().take((MAX_BODY_SIZE as u64) + 1);
1357 let _ = limited.read_to_string(&mut body);
1358 }
1359
1360 if body.len() > MAX_BODY_SIZE {
1361 let err_body = json_error(
1362 "PAYLOAD_TOO_LARGE",
1363 &format!(
1364 "Request body exceeds maximum size of {} bytes",
1365 MAX_BODY_SIZE,
1366 ),
1367 );
1368 let response = with_security_headers(
1369 Response::from_string(&err_body)
1370 .with_status_code(413u16)
1371 .with_header(Header::from_bytes("Content-Type", "application/json").unwrap())
1372 .with_header(
1373 Header::from_bytes(
1374 "Access-Control-Allow-Origin",
1375 cors_origin.as_bytes().to_vec(),
1376 )
1377 .unwrap(),
1378 ),
1379 );
1380 let _ = request.respond(response);
1381 mt.record_request(method.as_str(), 413);
1382 continue;
1383 }
1384
1385 if method == Method::Get {
1389 if let Some(rest) = url.strip_prefix("/api/shards/") {
1390 let rest = rest.split('?').next().unwrap_or(rest);
1391 if let Some(shard_id) = rest.strip_suffix("/connect") {
1392 if auth_ctx.user_id.is_none() {
1397 let err = json_error(
1398 "AUTH_REQUIRED",
1399 "Shard connect requires an authenticated session",
1400 );
1401 let response = with_security_headers(
1402 Response::from_string(&err)
1403 .with_status_code(401u16)
1404 .with_header(
1405 Header::from_bytes("Content-Type", "application/json").unwrap(),
1406 )
1407 .with_header(
1408 Header::from_bytes(
1409 "Access-Control-Allow-Origin",
1410 cors_origin.as_bytes().to_vec(),
1411 )
1412 .unwrap(),
1413 ),
1414 );
1415 let _ = request.respond(response);
1416 mt.record_request("GET", 401);
1417 continue;
1418 }
1419 let shards = match &shards_ref {
1420 Some(s) => Arc::clone(s),
1421 None => {
1422 let err = json_error(
1423 "SHARDS_NOT_AVAILABLE",
1424 "Shard system is not configured",
1425 );
1426 let response = with_security_headers(
1427 Response::from_string(&err)
1428 .with_status_code(503u16)
1429 .with_header(
1430 Header::from_bytes("Content-Type", "application/json")
1431 .unwrap(),
1432 )
1433 .with_header(
1434 Header::from_bytes(
1435 "Access-Control-Allow-Origin",
1436 cors_origin.as_bytes().to_vec(),
1437 )
1438 .unwrap(),
1439 ),
1440 );
1441 let _ = request.respond(response);
1442 mt.record_request("GET", 503);
1443 continue;
1444 }
1445 };
1446 let shard = match shards.get(shard_id) {
1447 Some(s) => s,
1448 None => {
1449 let err = json_error(
1450 "SHARD_NOT_FOUND",
1451 &format!("Shard \"{shard_id}\" not found"),
1452 );
1453 let response = with_security_headers(
1454 Response::from_string(&err)
1455 .with_status_code(404u16)
1456 .with_header(
1457 Header::from_bytes("Content-Type", "application/json")
1458 .unwrap(),
1459 )
1460 .with_header(
1461 Header::from_bytes(
1462 "Access-Control-Allow-Origin",
1463 cors_origin.as_bytes().to_vec(),
1464 )
1465 .unwrap(),
1466 ),
1467 );
1468 let _ = request.respond(response);
1469 mt.record_request("GET", 404);
1470 continue;
1471 }
1472 };
1473
1474 let sub_id = url
1477 .split("sid=")
1478 .nth(1)
1479 .and_then(|s| s.split('&').next())
1480 .map(|s| s.to_string())
1481 .or_else(|| auth_ctx.user_id.clone())
1482 .unwrap_or_else(|| {
1483 format!(
1484 "anon_{}",
1485 std::time::SystemTime::now()
1486 .duration_since(std::time::UNIX_EPOCH)
1487 .unwrap_or_default()
1488 .as_nanos()
1489 )
1490 });
1491 let subscriber_id = pylon_realtime::SubscriberId::new(sub_id);
1492
1493 let (tx, rx) = std::sync::mpsc::channel::<Vec<u8>>();
1494 let streaming_body = StreamingBody::new(rx);
1495
1496 let tx_clone = tx.clone();
1497 let sink: pylon_realtime::SnapshotSink =
1498 Box::new(move |tick: u64, bytes: &[u8]| {
1499 let mut frame = format!("id: {tick}\ndata: ").into_bytes();
1502 frame.extend_from_slice(bytes);
1503 frame.extend_from_slice(b"\n\n");
1504 let _ = tx_clone.send(frame);
1505 });
1506
1507 let shard_auth = pylon_realtime::ShardAuth {
1508 user_id: auth_ctx.user_id.clone(),
1509 is_admin: auth_ctx.is_admin,
1510 };
1511 if let Err(e) = shard.add_subscriber(subscriber_id.clone(), sink, &shard_auth) {
1512 let (status, code) = match &e {
1513 pylon_realtime::ShardError::Unauthorized(_) => (403u16, "UNAUTHORIZED"),
1514 _ => (429u16, "SUBSCRIBE_FAILED"),
1515 };
1516 let err = json_error(code, &e.to_string());
1517 let response = with_security_headers(
1518 Response::from_string(&err)
1519 .with_status_code(status)
1520 .with_header(
1521 Header::from_bytes("Content-Type", "application/json").unwrap(),
1522 )
1523 .with_header(
1524 Header::from_bytes(
1525 "Access-Control-Allow-Origin",
1526 cors_origin.as_bytes().to_vec(),
1527 )
1528 .unwrap(),
1529 ),
1530 );
1531 let _ = request.respond(response);
1532 mt.record_request("GET", status);
1533 continue;
1534 }
1535
1536 {
1539 let shard_cleanup = Arc::clone(&shard);
1540 let sub_id_cleanup = subscriber_id.clone();
1541 let tx_liveness = tx.clone();
1542 std::thread::spawn(move || {
1543 loop {
1546 std::thread::sleep(std::time::Duration::from_secs(30));
1547 if tx_liveness.send(b": heartbeat\n\n".to_vec()).is_err() {
1548 shard_cleanup.remove_subscriber(&sub_id_cleanup);
1549 return;
1550 }
1551 if !shard_cleanup.is_running() {
1552 return;
1553 }
1554 }
1555 });
1556 }
1557
1558 let response = with_security_headers(Response::new(
1559 tiny_http::StatusCode(200),
1560 vec![
1561 Header::from_bytes("Content-Type", "text/event-stream").unwrap(),
1562 Header::from_bytes("Cache-Control", "no-cache").unwrap(),
1563 Header::from_bytes("Connection", "keep-alive").unwrap(),
1564 Header::from_bytes(
1565 "Access-Control-Allow-Origin",
1566 cors_origin.as_bytes().to_vec(),
1567 )
1568 .unwrap(),
1569 ],
1570 streaming_body,
1571 None,
1572 None,
1573 ));
1574 let _ = request.respond(response);
1575 mt.record_request("GET", 200);
1576 continue;
1577 }
1578 }
1579 }
1580
1581 if method == Method::Post
1583 && url.starts_with("/api/fn/")
1584 && url != "/api/fn/traces"
1585 && request.headers().iter().any(|h| {
1586 (h.field.as_str() == "Accept" || h.field.as_str() == "accept")
1587 && h.value.as_str().contains("text/event-stream")
1588 })
1589 {
1590 let fn_name = url
1591 .strip_prefix("/api/fn/")
1592 .unwrap_or("")
1593 .split('?')
1594 .next()
1595 .unwrap_or("")
1596 .to_string();
1597
1598 if let Some(fn_ops) = &fn_ops_maybe {
1599 if pylon_router::FnOps::get_fn(fn_ops.as_ref(), &fn_name).is_none() {
1603 let err = json_error(
1604 "FN_NOT_FOUND",
1605 &format!("Function \"{fn_name}\" is not registered"),
1606 );
1607 let response = with_security_headers(
1608 Response::from_string(&err)
1609 .with_status_code(404u16)
1610 .with_header(
1611 Header::from_bytes("Content-Type", "application/json").unwrap(),
1612 )
1613 .with_header(
1614 Header::from_bytes(
1615 "Access-Control-Allow-Origin",
1616 cors_origin.as_bytes().to_vec(),
1617 )
1618 .unwrap(),
1619 ),
1620 );
1621 let _ = request.respond(response);
1622 mt.record_request("POST", 404);
1623 continue;
1624 }
1625 let identity = auth_ctx.user_id.as_deref().unwrap_or("anon");
1627 if let Err(retry_after) =
1628 pylon_router::FnOps::check_rate_limit(fn_ops.as_ref(), &fn_name, identity)
1629 {
1630 let body = format!(
1631 r#"{{"error":{{"code":"RATE_LIMITED","message":"Function \"{fn_name}\" rate limit exceeded","retry_after_secs":{retry_after}}}}}"#
1632 );
1633 let response = with_security_headers(
1634 Response::from_string(&body)
1635 .with_status_code(429u16)
1636 .with_header(
1637 Header::from_bytes("Content-Type", "application/json").unwrap(),
1638 )
1639 .with_header(
1640 Header::from_bytes(
1641 "Access-Control-Allow-Origin",
1642 cors_origin.as_bytes().to_vec(),
1643 )
1644 .unwrap(),
1645 ),
1646 );
1647 let _ = request.respond(response);
1648 mt.record_request("POST", 429);
1649 continue;
1650 }
1651
1652 let args: serde_json::Value =
1653 serde_json::from_str(&body).unwrap_or(serde_json::json!({}));
1654
1655 let auth = pylon_functions::protocol::AuthInfo {
1656 user_id: auth_ctx.user_id.clone(),
1657 is_admin: auth_ctx.is_admin,
1658 tenant_id: auth_ctx.tenant_id.clone(),
1659 };
1660
1661 let (tx, rx) = std::sync::mpsc::channel::<Vec<u8>>();
1662 let streaming_body = StreamingBody::new(rx);
1663
1664 let fn_ops_cl = Arc::clone(fn_ops);
1665 let tx_stream = tx.clone();
1666 std::thread::spawn(move || {
1667 let tx_cb = tx_stream.clone();
1668 let on_stream: Box<dyn FnMut(&str) + Send> = Box::new(move |chunk: &str| {
1669 let sse = format!("data: {}\n\n", chunk);
1670 let _ = tx_cb.send(sse.into_bytes());
1671 });
1672
1673 let result = pylon_router::FnOps::call(
1674 fn_ops_cl.as_ref(),
1675 &fn_name,
1676 args,
1677 auth,
1678 Some(on_stream),
1679 None, );
1681 match result {
1682 Ok((value, _trace)) => {
1683 let done = format!(
1684 "event: result\ndata: {}\n\n",
1685 serde_json::to_string(&value).unwrap_or_else(|_| "null".into())
1686 );
1687 let _ = tx_stream.send(done.into_bytes());
1688 }
1689 Err(e) => {
1690 let err = format!(
1691 "event: error\ndata: {}\n\n",
1692 serde_json::json!({"code": e.code, "message": e.message})
1693 );
1694 let _ = tx_stream.send(err.into_bytes());
1695 }
1696 }
1697 });
1698
1699 let response = with_security_headers(Response::new(
1700 tiny_http::StatusCode(200),
1701 vec![
1702 Header::from_bytes("Content-Type", "text/event-stream").unwrap(),
1703 Header::from_bytes("Cache-Control", "no-cache").unwrap(),
1704 Header::from_bytes("Connection", "keep-alive").unwrap(),
1705 Header::from_bytes(
1706 "Access-Control-Allow-Origin",
1707 cors_origin.as_bytes().to_vec(),
1708 )
1709 .unwrap(),
1710 ],
1711 streaming_body,
1712 None,
1713 None,
1714 ));
1715 let _ = request.respond(response);
1716 mt.record_request("POST", 200);
1717 continue;
1718 }
1719 }
1720
1721 if url == "/api/ai/stream" && method == Method::Post {
1723 if auth_ctx.user_id.is_none() {
1726 let err = json_error(
1727 "AUTH_REQUIRED",
1728 "/api/ai/stream requires an authenticated session",
1729 );
1730 let response = with_security_headers(
1731 Response::from_string(&err)
1732 .with_status_code(401u16)
1733 .with_header(
1734 Header::from_bytes("Content-Type", "application/json").unwrap(),
1735 )
1736 .with_header(
1737 Header::from_bytes(
1738 "Access-Control-Allow-Origin",
1739 cors_origin.as_bytes().to_vec(),
1740 )
1741 .unwrap(),
1742 ),
1743 );
1744 let _ = request.respond(response);
1745 mt.record_request("POST", 401);
1746 continue;
1747 }
1748 let ai_provider = std::env::var("PYLON_AI_PROVIDER").unwrap_or_default();
1749 let ai_key = std::env::var("PYLON_AI_API_KEY").unwrap_or_default();
1750 let ai_model = std::env::var("PYLON_AI_MODEL").unwrap_or_default();
1751 let ai_base = std::env::var("PYLON_AI_BASE_URL").unwrap_or_default();
1752
1753 if ai_key.is_empty() && ai_provider != "custom" {
1754 let err = json_error(
1755 "AI_NOT_CONFIGURED",
1756 "Set PYLON_AI_PROVIDER and PYLON_AI_API_KEY",
1757 );
1758 let response = with_security_headers(
1759 Response::from_string(&err)
1760 .with_status_code(503u16)
1761 .with_header(
1762 Header::from_bytes("Content-Type", "application/json").unwrap(),
1763 )
1764 .with_header(
1765 Header::from_bytes(
1766 "Access-Control-Allow-Origin",
1767 cors_origin.as_bytes().to_vec(),
1768 )
1769 .unwrap(),
1770 ),
1771 );
1772 let _ = request.respond(response);
1773 mt.record_request("POST", 503);
1774 continue;
1775 }
1776
1777 let parsed: serde_json::Value = match serde_json::from_str(&body) {
1778 Ok(v) => v,
1779 Err(_) => {
1780 let err = json_error("INVALID_JSON", "Invalid request body");
1781 let response = with_security_headers(
1782 Response::from_string(&err)
1783 .with_status_code(400u16)
1784 .with_header(
1785 Header::from_bytes("Content-Type", "application/json").unwrap(),
1786 )
1787 .with_header(
1788 Header::from_bytes(
1789 "Access-Control-Allow-Origin",
1790 cors_origin.as_bytes().to_vec(),
1791 )
1792 .unwrap(),
1793 ),
1794 );
1795 let _ = request.respond(response);
1796 mt.record_request("POST", 400);
1797 continue;
1798 }
1799 };
1800
1801 let messages: Vec<AiMessage> = match parsed.get("messages").and_then(|m| m.as_array()) {
1802 Some(arr) => arr
1803 .iter()
1804 .filter_map(|m| {
1805 let role = m.get("role")?.as_str()?.to_string();
1806 let content = m.get("content")?.as_str()?.to_string();
1807 Some(AiMessage { role, content })
1808 })
1809 .collect(),
1810 None => {
1811 let err = json_error("MISSING_FIELD", "\"messages\" array is required");
1812 let response = with_security_headers(
1813 Response::from_string(&err)
1814 .with_status_code(400u16)
1815 .with_header(
1816 Header::from_bytes("Content-Type", "application/json").unwrap(),
1817 )
1818 .with_header(
1819 Header::from_bytes(
1820 "Access-Control-Allow-Origin",
1821 cors_origin.as_bytes().to_vec(),
1822 )
1823 .unwrap(),
1824 ),
1825 );
1826 let _ = request.respond(response);
1827 mt.record_request("POST", 400);
1828 continue;
1829 }
1830 };
1831
1832 let model = parsed
1834 .get("model")
1835 .and_then(|m| m.as_str())
1836 .map(|s| s.to_string())
1837 .unwrap_or(ai_model);
1838
1839 let proxy = match ai_provider.as_str() {
1840 "anthropic" => AiProxyPlugin::anthropic(&ai_key, &model),
1841 "openai" => AiProxyPlugin::openai(&ai_key, &model),
1842 "custom" => AiProxyPlugin::custom_with_model(&ai_base, &ai_key, &model),
1843 _ => AiProxyPlugin::openai(&ai_key, &model),
1844 };
1845
1846 let (tx, rx) = std::sync::mpsc::channel::<Vec<u8>>();
1849 let streaming_body = StreamingBody::new(rx);
1850
1851 std::thread::spawn(move || {
1854 let result = proxy.stream_completion(&messages, &mut |chunk| {
1855 let sse = format!(
1856 "data: {}
1857
1858",
1859 serde_json::json!({
1860 "choices": [{"index": 0, "delta": {"content": chunk}}]
1861 })
1862 );
1863 let _ = tx.send(sse.into_bytes());
1864 });
1865
1866 match result {
1868 Ok(_) => {
1869 let _ = tx.send(
1870 b"data: [DONE]
1871
1872"
1873 .to_vec(),
1874 );
1875 }
1876 Err(e) => {
1877 let err_event = format!(
1878 "data: {}
1879
1880",
1881 serde_json::json!({"error": {"message": e, "type": "stream_error"}})
1882 );
1883 let _ = tx.send(err_event.into_bytes());
1884 }
1885 }
1886 });
1888
1889 let response = with_security_headers(Response::new(
1890 tiny_http::StatusCode(200),
1891 vec![
1892 Header::from_bytes("Content-Type", "text/event-stream").unwrap(),
1893 Header::from_bytes("Cache-Control", "no-cache").unwrap(),
1894 Header::from_bytes("Connection", "keep-alive").unwrap(),
1895 Header::from_bytes(
1896 "Access-Control-Allow-Origin",
1897 cors_origin.as_bytes().to_vec(),
1898 )
1899 .unwrap(),
1900 ],
1901 streaming_body,
1902 None, None,
1904 ));
1905 let _ = request.respond(response);
1906 mt.record_request("POST", 200);
1907 continue;
1908 }
1909
1910 let (status, response_body, content_type, is_studio, extra_headers) = if (url == "/studio"
1921 || url == "/studio/")
1922 && method == Method::Get
1923 {
1924 if !is_dev && !auth_ctx.is_admin {
1925 let body = json_error(
1926 "AUTH_REQUIRED",
1927 "/studio requires admin auth in production (set PYLON_ADMIN_TOKEN and pass it as Bearer)",
1928 );
1929 let response = with_security_headers(
1930 Response::from_string(&body)
1931 .with_status_code(401u16)
1932 .with_header(
1933 Header::from_bytes("Content-Type", "application/json").unwrap(),
1934 )
1935 .with_header(
1936 Header::from_bytes(
1937 "Access-Control-Allow-Origin",
1938 cors_origin.as_bytes().to_vec(),
1939 )
1940 .unwrap(),
1941 ),
1942 );
1943 let _ = request.respond(response);
1944 mt.record_request("GET", 401);
1945 continue;
1946 }
1947 let host = request
1954 .headers()
1955 .iter()
1956 .find(|h| h.field.equiv("Host"))
1957 .map(|h| h.value.as_str().to_string())
1958 .unwrap_or_else(|| format!("localhost:{port}"));
1959 let scheme = request
1960 .headers()
1961 .iter()
1962 .find(|h| h.field.equiv("X-Forwarded-Proto"))
1963 .map(|h| h.value.as_str().to_string())
1964 .unwrap_or_else(|| "http".to_string());
1965 let base = format!("{scheme}://{host}");
1966 let html = pylon_studio_api::generate_studio_html(rt.manifest(), &base);
1967 (
1968 200u16,
1969 html,
1970 "text/html",
1971 true,
1972 Vec::<(String, String)>::new(),
1973 )
1974 } else {
1975 let meta = pylon_plugin::RequestMeta {
1979 peer_ip: peer_ip.as_str(),
1980 };
1981 if let Err(e) = pr.run_on_request_with_meta(method.as_str(), &url, &auth_ctx, &meta) {
1982 (
1983 e.status,
1984 json_error(&e.code, &e.message),
1985 "application/json",
1986 false,
1987 Vec::new(),
1988 )
1989 } else if let Some((s, b)) =
1990 pr.try_handle_route(method.as_str(), &url, &body, &auth_ctx)
1991 {
1992 (s, b, "application/json", false, Vec::new())
1994 } else {
1995 let notifier = WsSseNotifier {
1996 ws: Arc::clone(&wh),
1997 sse: Arc::clone(&sh),
1998 };
1999 let openapi_gen = RuntimeOpenApiGenerator {
2000 manifest: rt.manifest(),
2001 };
2002 let file_ops = LocalFileOps::new_default();
2003 let cache_adapter = CacheAdapter(Arc::clone(&ca));
2004 let pubsub_adapter = PubSubAdapter(Arc::clone(&ps));
2005 let email_adapter = EmailAdapter::from_env();
2006 let fn_ops: Option<&dyn pylon_router::FnOps> =
2007 fn_ops_ref.as_deref().map(|f| f as &dyn pylon_router::FnOps);
2008 let shard_adapter = shards_ref.as_ref().map(|reg| ShardOpsAdapter {
2009 registry: Arc::clone(reg),
2010 });
2011 let shard_ops: Option<&dyn pylon_router::ShardOps> = shard_adapter
2012 .as_ref()
2013 .map(|a| a as &dyn pylon_router::ShardOps);
2014 let plugin_hooks = PluginHooksAdapter(Arc::clone(&pr));
2015 let request_headers: Vec<(String, String)> = request
2020 .headers()
2021 .iter()
2022 .map(|h| (h.field.as_str().to_string(), h.value.as_str().to_string()))
2023 .collect();
2024 let router_ctx = pylon_router::RouterContext {
2025 store: rt.as_ref(),
2026 session_store: &ss,
2027 magic_codes: &mc,
2028 oauth_state: &os,
2029 account_store: &acc,
2030 api_keys: &ak,
2031 orgs: &og,
2032 siwe: &sw,
2033 phone_codes: &pcd,
2034 passkeys: &pks,
2035 policy_engine: &pe,
2036 change_log: &cl,
2037 notifier: ¬ifier,
2038 rooms: rm.as_ref(),
2039 cache: &cache_adapter,
2040 pubsub: &pubsub_adapter,
2041 jobs: jq.as_ref(),
2042 scheduler: sc.as_ref(),
2043 workflows: we.as_ref(),
2044 files: &file_ops,
2045 openapi: &openapi_gen,
2046 functions: fn_ops,
2047 email: &email_adapter,
2048 shards: shard_ops,
2049 plugin_hooks: &plugin_hooks,
2050 auth_ctx: &auth_ctx,
2051 trusted_origins: &trusted_origins_ref,
2052 is_dev,
2053 request_headers: &request_headers,
2054 peer_ip: peer_ip.as_str(),
2055 cookie_config: cookie_config.as_ref(),
2056 response_headers: std::cell::RefCell::new(Vec::new()),
2057 };
2058 let http_method = HttpMethod::from_str(method.as_str());
2059 let (s, b, _ct) = pylon_router::route(
2060 &router_ctx,
2061 http_method,
2062 &url,
2063 &body,
2064 auth_token.as_deref(),
2065 );
2066 let extra_headers = router_ctx.take_response_headers();
2067 (s, b, "application/json", false, extra_headers)
2068 }
2069 };
2070
2071 let mut response = Response::from_string(&response_body)
2072 .with_status_code(status)
2073 .with_header(Header::from_bytes("Content-Type", content_type).unwrap())
2074 .with_header(
2075 Header::from_bytes(
2076 "Access-Control-Allow-Origin",
2077 cors_origin.as_bytes().to_vec(),
2078 )
2079 .unwrap(),
2080 )
2081 .with_header(
2082 Header::from_bytes(
2083 "Access-Control-Allow-Methods",
2084 "GET, POST, PATCH, DELETE, OPTIONS",
2085 )
2086 .unwrap(),
2087 )
2088 .with_header(
2089 Header::from_bytes(
2090 "Access-Control-Allow-Headers",
2091 "Content-Type, Authorization",
2092 )
2093 .unwrap(),
2094 );
2095 if allow_credentials {
2100 response = response
2101 .with_header(
2102 Header::from_bytes("Access-Control-Allow-Credentials", "true").unwrap(),
2103 )
2104 .with_header(Header::from_bytes("Vary", "Origin").unwrap());
2105 }
2106
2107 for (name, value) in extra_headers {
2114 if let Ok(h) = Header::from_bytes(name.as_bytes(), value.as_bytes().to_vec()) {
2115 response = response.with_header(h);
2116 }
2117 }
2118
2119 if is_studio {
2132 response = response.with_header(
2133 Header::from_bytes(
2134 "Content-Security-Policy",
2135 "default-src 'self' 'unsafe-inline' 'unsafe-eval' https://cdn.tailwindcss.com https://unpkg.com ws: wss:",
2136 ).unwrap(),
2137 );
2138 }
2139
2140 let response = with_security_headers(response);
2141
2142 let _ = request.respond(response);
2143 mt.record_request(method.as_str(), status);
2144 }
2145
2146 tracing::warn!("Shutting down gracefully...");
2147
2148 let drain_timeout = std::time::Duration::from_secs(
2151 std::env::var("PYLON_DRAIN_SECS")
2152 .ok()
2153 .and_then(|s| s.parse().ok())
2154 .unwrap_or(10),
2155 );
2156 let start = Instant::now();
2157
2158 if let Some(reg) = &shard_registry {
2160 for id in reg.ids() {
2161 if let Some(shard) = reg.get(&id) {
2162 shard.stop();
2163 }
2164 }
2165 }
2166
2167 let _ = &scheduler; while start.elapsed() < drain_timeout {
2172 let pending_jobs = job_queue.stats().pending;
2173 if pending_jobs == 0 {
2174 break;
2175 }
2176 std::thread::sleep(std::time::Duration::from_millis(100));
2177 }
2178
2179 let elapsed = start.elapsed();
2180 tracing::warn!(
2181 "Drain complete in {:.1}s (timeout {}s)",
2182 elapsed.as_secs_f32(),
2183 drain_timeout.as_secs()
2184 );
2185 Ok(())
2186}
2187
2188fn json_error(code: &str, message: &str) -> String {
2193 pylon_router::json_error(code, message)
2194}
2195
2196struct AuthStores {
2206 session_store: Arc<SessionStore>,
2207 magic_codes: Arc<pylon_auth::MagicCodeStore>,
2208 oauth_state: Arc<pylon_auth::OAuthStateStore>,
2209 account_store: Arc<pylon_auth::AccountStore>,
2210 api_keys: Arc<pylon_auth::api_key::ApiKeyStore>,
2211 orgs: Arc<pylon_auth::org::OrgStore>,
2212 siwe: Arc<pylon_auth::siwe::NonceStore>,
2213 phone_codes: Arc<pylon_auth::phone::PhoneCodeStore>,
2214 passkeys: Arc<pylon_auth::webauthn::PasskeyStore>,
2215}
2216
2217fn jwt_secret() -> Option<&'static String> {
2222 static CELL: std::sync::OnceLock<Option<String>> = std::sync::OnceLock::new();
2223 CELL.get_or_init(|| std::env::var("PYLON_JWT_SECRET").ok().filter(|s| !s.is_empty()))
2224 .as_ref()
2225}
2226
2227fn jwt_issuer() -> Option<&'static String> {
2228 static CELL: std::sync::OnceLock<Option<String>> = std::sync::OnceLock::new();
2229 CELL.get_or_init(|| std::env::var("PYLON_JWT_ISSUER").ok().filter(|s| !s.is_empty()))
2230 .as_ref()
2231}
2232
2233fn build_auth_stores(app_db_path: Option<&str>, session_lifetime: u64) -> AuthStores {
2234 let force_in_memory = std::env::var("PYLON_SESSION_IN_MEMORY")
2237 .map(|v| v == "1" || v == "true")
2238 .unwrap_or(false);
2239
2240 let pg_url = std::env::var("DATABASE_URL")
2243 .ok()
2244 .filter(|u| u.starts_with("postgres://") || u.starts_with("postgresql://"));
2245
2246 if let Some(url) = pg_url {
2247 if force_in_memory {
2248 return in_memory_auth_stores(session_lifetime);
2251 }
2252 return build_pg_auth_stores(&url, session_lifetime);
2253 }
2254
2255 let sqlite_path = std::env::var("PYLON_SESSION_DB")
2256 .ok()
2257 .or_else(|| app_db_path.map(|p| format!("{p}.sessions.db")));
2258
2259 match (force_in_memory, sqlite_path) {
2260 (true, _) | (_, None) => in_memory_auth_stores(session_lifetime),
2261 (false, Some(path)) => build_sqlite_auth_stores(&path, session_lifetime),
2262 }
2263}
2264
2265fn in_memory_auth_stores(session_lifetime: u64) -> AuthStores {
2266 AuthStores {
2267 session_store: Arc::new(SessionStore::new().with_lifetime(session_lifetime)),
2268 magic_codes: Arc::new(pylon_auth::MagicCodeStore::new()),
2269 oauth_state: Arc::new(pylon_auth::OAuthStateStore::new()),
2270 account_store: Arc::new(pylon_auth::AccountStore::new()),
2271 api_keys: Arc::new(pylon_auth::api_key::ApiKeyStore::new()),
2272 orgs: Arc::new(pylon_auth::org::OrgStore::new()),
2273 siwe: Arc::new(pylon_auth::siwe::NonceStore::new()),
2274 phone_codes: Arc::new(pylon_auth::phone::PhoneCodeStore::new()),
2275 passkeys: Arc::new(pylon_auth::webauthn::PasskeyStore::new()),
2276 }
2277}
2278
2279fn build_sqlite_auth_stores(path: &str, session_lifetime: u64) -> AuthStores {
2280 let session_store = match crate::session_backend::SqliteSessionBackend::open(path) {
2281 Ok(b) => {
2282 tracing::info!("[pylon] Auth state (SQLite): {path}");
2283 SessionStore::with_backend(Box::new(b)).with_lifetime(session_lifetime)
2284 }
2285 Err(e) => {
2286 tracing::warn!("[pylon] could not open session DB {path}: {e}. In-memory fallback.");
2287 SessionStore::new().with_lifetime(session_lifetime)
2288 }
2289 };
2290 let magic_codes = match crate::magic_code_backend::SqliteMagicCodeBackend::open(path) {
2291 Ok(b) => pylon_auth::MagicCodeStore::with_backend(Box::new(b)),
2292 Err(e) => {
2293 tracing::warn!("[pylon] magic-code SQLite backend unavailable: {e}");
2294 pylon_auth::MagicCodeStore::new()
2295 }
2296 };
2297 let oauth_state = match crate::oauth_backend::SqliteOAuthBackend::open(path) {
2298 Ok(b) => pylon_auth::OAuthStateStore::with_backend(Box::new(b)),
2299 Err(e) => {
2300 tracing::warn!("[pylon] OAuth state SQLite backend unavailable: {e}");
2301 pylon_auth::OAuthStateStore::new()
2302 }
2303 };
2304 let account_store = match crate::account_backend::SqliteAccountBackend::open(path) {
2305 Ok(b) => pylon_auth::AccountStore::with_backend(Box::new(b)),
2306 Err(e) => {
2307 tracing::warn!("[pylon] account-link SQLite backend unavailable: {e}");
2308 pylon_auth::AccountStore::new()
2309 }
2310 };
2311 let api_keys = match crate::api_key_backend::SqliteApiKeyBackend::open(path) {
2312 Ok(b) => pylon_auth::api_key::ApiKeyStore::with_backend(Box::new(b)),
2313 Err(e) => {
2314 tracing::warn!("[pylon] api-key SQLite backend unavailable: {e}");
2315 pylon_auth::api_key::ApiKeyStore::new()
2316 }
2317 };
2318 let orgs = match crate::org_backend::SqliteOrgBackend::open(path) {
2319 Ok(b) => pylon_auth::org::OrgStore::with_backend(Box::new(b)),
2320 Err(e) => {
2321 tracing::warn!("[pylon] org SQLite backend unavailable: {e}");
2322 pylon_auth::org::OrgStore::new()
2323 }
2324 };
2325 AuthStores {
2326 session_store: Arc::new(session_store),
2327 magic_codes: Arc::new(magic_codes),
2328 oauth_state: Arc::new(oauth_state),
2329 account_store: Arc::new(account_store),
2330 api_keys: Arc::new(api_keys),
2331 orgs: Arc::new(orgs),
2332 siwe: Arc::new(pylon_auth::siwe::NonceStore::new()),
2333 phone_codes: Arc::new(pylon_auth::phone::PhoneCodeStore::new()),
2334 passkeys: Arc::new(pylon_auth::webauthn::PasskeyStore::new()),
2335 }
2336}
2337
2338fn build_pg_auth_stores(url: &str, session_lifetime: u64) -> AuthStores {
2339 let session_store = match crate::session_backend::PostgresSessionBackend::connect(url) {
2344 Ok(b) => {
2345 tracing::info!("[pylon] Auth state (Postgres): {url}");
2346 SessionStore::with_backend(Box::new(b)).with_lifetime(session_lifetime)
2347 }
2348 Err(e) => {
2349 tracing::warn!("[pylon] PG session backend unavailable: {e}. In-memory fallback.");
2350 SessionStore::new().with_lifetime(session_lifetime)
2351 }
2352 };
2353 let magic_codes = match crate::magic_code_backend::PostgresMagicCodeBackend::connect(url) {
2354 Ok(b) => pylon_auth::MagicCodeStore::with_backend(Box::new(b)),
2355 Err(e) => {
2356 tracing::warn!("[pylon] PG magic-code backend unavailable: {e}");
2357 pylon_auth::MagicCodeStore::new()
2358 }
2359 };
2360 let oauth_state = match crate::oauth_backend::PostgresOAuthBackend::connect(url) {
2361 Ok(b) => pylon_auth::OAuthStateStore::with_backend(Box::new(b)),
2362 Err(e) => {
2363 tracing::warn!("[pylon] PG OAuth state backend unavailable: {e}");
2364 pylon_auth::OAuthStateStore::new()
2365 }
2366 };
2367 let account_store = match crate::account_backend::PostgresAccountBackend::connect(url) {
2368 Ok(b) => pylon_auth::AccountStore::with_backend(Box::new(b)),
2369 Err(e) => {
2370 tracing::warn!("[pylon] PG account-link backend unavailable: {e}");
2371 pylon_auth::AccountStore::new()
2372 }
2373 };
2374 let api_keys = match crate::api_key_backend::PostgresApiKeyBackend::connect(url) {
2375 Ok(b) => pylon_auth::api_key::ApiKeyStore::with_backend(Box::new(b)),
2376 Err(e) => {
2377 tracing::warn!("[pylon] PG api-key backend unavailable: {e}");
2378 pylon_auth::api_key::ApiKeyStore::new()
2379 }
2380 };
2381 let orgs = match crate::org_backend::PostgresOrgBackend::connect(url) {
2382 Ok(b) => pylon_auth::org::OrgStore::with_backend(Box::new(b)),
2383 Err(e) => {
2384 tracing::warn!("[pylon] PG org backend unavailable: {e}");
2385 pylon_auth::org::OrgStore::new()
2386 }
2387 };
2388 AuthStores {
2389 session_store: Arc::new(session_store),
2390 magic_codes: Arc::new(magic_codes),
2391 oauth_state: Arc::new(oauth_state),
2392 account_store: Arc::new(account_store),
2393 api_keys: Arc::new(api_keys),
2394 orgs: Arc::new(orgs),
2395 siwe: Arc::new(pylon_auth::siwe::NonceStore::new()),
2396 phone_codes: Arc::new(pylon_auth::phone::PhoneCodeStore::new()),
2397 passkeys: Arc::new(pylon_auth::webauthn::PasskeyStore::new()),
2398 }
2399}
2400
2401#[allow(dead_code)]
2411fn build_session_store(app_db_path: Option<&str>) -> SessionStore {
2412 if std::env::var("PYLON_SESSION_IN_MEMORY")
2413 .map(|v| v == "1" || v == "true")
2414 .unwrap_or(false)
2415 {
2416 return SessionStore::new();
2417 }
2418 let explicit = std::env::var("PYLON_SESSION_DB").ok();
2419 let default_path = app_db_path.map(|p| format!("{p}.sessions.db"));
2420 let path = match explicit.or(default_path) {
2421 Some(p) => p,
2422 None => return SessionStore::new(),
2423 };
2424 match crate::session_backend::SqliteSessionBackend::open(&path) {
2425 Ok(backend) => {
2426 tracing::info!("[pylon] Session persistence enabled: {path}");
2427 SessionStore::with_backend(Box::new(backend))
2428 }
2429 Err(e) => {
2430 tracing::warn!(
2431 "[pylon] could not open session DB {path}: {e}. Falling back to in-memory sessions."
2432 );
2433 SessionStore::new()
2434 }
2435 }
2436}
2437
2438fn parse_multipart_first_file(
2449 body: &[u8],
2450 content_type_header: &str,
2451) -> Option<(String, String, Vec<u8>)> {
2452 let boundary_param = content_type_header
2454 .split(';')
2455 .find_map(|p| p.trim().strip_prefix("boundary="))?;
2456 let boundary = boundary_param.trim_matches('"');
2457 let delimiter = format!("--{boundary}");
2458 let delimiter_bytes = delimiter.as_bytes();
2459
2460 let mut pos = 0usize;
2462 while pos < body.len() {
2463 let next = find_subslice(&body[pos..], delimiter_bytes)?;
2465 let part_start = pos + next + delimiter_bytes.len();
2466 if part_start + 2 > body.len() {
2468 return None;
2469 }
2470 if &body[part_start..part_start + 2] == b"--" {
2471 return None; }
2473 let header_start = part_start + skip_crlf(&body[part_start..]);
2474
2475 let header_end_offset = find_subslice(&body[header_start..], b"\r\n\r\n")?;
2477 let headers = &body[header_start..header_start + header_end_offset];
2478 let data_start = header_start + header_end_offset + 4;
2479
2480 let next_delim_offset = find_subslice(&body[data_start..], delimiter_bytes)?;
2482 let mut data_end = data_start + next_delim_offset;
2484 if data_end >= 2 && &body[data_end - 2..data_end] == b"\r\n" {
2485 data_end -= 2;
2486 }
2487
2488 let headers_str = std::str::from_utf8(headers).ok()?;
2490 let mut filename: Option<String> = None;
2491 let mut part_ct = String::from("application/octet-stream");
2492 let mut has_file = false;
2493 for line in headers_str.split("\r\n") {
2494 let lower = line.to_ascii_lowercase();
2495 if let Some(rest) = lower.strip_prefix("content-disposition:") {
2496 if rest.contains("filename=") {
2497 has_file = true;
2498 if let Some(start) = line.find("filename=\"") {
2500 let from = start + 10;
2501 if let Some(end_offset) = line[from..].find('"') {
2502 filename = Some(line[from..from + end_offset].to_string());
2503 }
2504 }
2505 }
2506 } else if let Some(rest) = lower.strip_prefix("content-type:") {
2507 part_ct = rest.trim().to_string();
2508 }
2509 }
2510
2511 if has_file {
2512 let name = filename.unwrap_or_else(|| "upload".into());
2513 return Some((name, part_ct, body[data_start..data_end].to_vec()));
2514 }
2515
2516 pos = data_end;
2517 }
2518 None
2519}
2520
2521fn find_subslice(haystack: &[u8], needle: &[u8]) -> Option<usize> {
2522 if needle.is_empty() || needle.len() > haystack.len() {
2523 return None;
2524 }
2525 haystack.windows(needle.len()).position(|w| w == needle)
2526}
2527
2528fn skip_crlf(buf: &[u8]) -> usize {
2529 if buf.len() >= 2 && &buf[0..2] == b"\r\n" {
2530 2
2531 } else if !buf.is_empty() && buf[0] == b'\n' {
2532 1
2533 } else {
2534 0
2535 }
2536}
2537
2538#[cfg(test)]
2539mod multipart_tests {
2540 use super::*;
2541
2542 #[test]
2543 fn parses_single_file() {
2544 let body = b"--bnd\r\n\
2545Content-Disposition: form-data; name=\"file\"; filename=\"hello.txt\"\r\n\
2546Content-Type: text/plain\r\n\
2547\r\n\
2548Hello world\r\n\
2549--bnd--\r\n";
2550 let ct = "multipart/form-data; boundary=bnd";
2551 let (name, content_type, bytes) = parse_multipart_first_file(body, ct).unwrap();
2552 assert_eq!(name, "hello.txt");
2553 assert_eq!(content_type, "text/plain");
2554 assert_eq!(bytes, b"Hello world");
2555 }
2556
2557 #[test]
2558 fn returns_none_without_file_part() {
2559 let body = b"--bnd\r\n\
2560Content-Disposition: form-data; name=\"field\"\r\n\
2561\r\n\
2562just text\r\n\
2563--bnd--\r\n";
2564 let ct = "multipart/form-data; boundary=bnd";
2565 assert!(parse_multipart_first_file(body, ct).is_none());
2566 }
2567
2568 #[test]
2569 fn returns_none_when_no_boundary() {
2570 assert!(parse_multipart_first_file(b"anything", "application/json").is_none());
2571 }
2572}