1use anyhow::Result;
9use axum::{
10 Json, Router,
11 extract::{Path, Query, Request, State},
12 http::StatusCode,
13 middleware::{self, Next},
14 response::{IntoResponse, Response as AxumResponse},
15 routing::{get, post},
16};
17use serde::Deserialize;
18use std::collections::HashMap;
19use std::net::IpAddr;
20use std::sync::Arc;
21
22use crate::jobstore::{JobDir, JobNotFound, generate_job_id, resolve_root};
23use crate::schema::{JobMeta, JobMetaJob, Response, RunData, SCHEMA_VERSION, StatusData, TailData};
24
/// Options consumed by [`execute`] when starting the HTTP server.
pub struct ServeOpts {
    /// Listen address as text (for example `127.0.0.1:8080`); [`execute`]
    /// parses it into a `SocketAddr` and fails if it is malformed.
    pub bind: String,
    /// Optional jobs-root override, handed to `resolve_root` as-is.
    pub root: Option<String>,
    /// Must be `true` before [`execute`] accepts a non-loopback `bind`.
    pub insecure: bool,
    /// Single origin allowed for CORS requests; the wildcard `"*"` is refused.
    pub allow_origin: Option<String>,
}
32
/// Immutable per-server configuration shared with every handler through an `Arc`.
#[derive(Clone)]
struct AppState {
    /// Optional jobs-root override forwarded to `resolve_root` by the handlers.
    root: Option<String>,
    /// Bearer token required by `auth_middleware`; `None` disables the check.
    token: Option<String>,
    /// Origin for which a CORS layer is installed, if any.
    allow_origin: Option<String>,
}
39
/// Reports whether the IP part of `addr` is a loopback address.
///
/// The port is ignored. IPv4 and IPv6 addresses are each judged by the
/// standard library's own `is_loopback` for that address family.
pub fn is_loopback(addr: &std::net::SocketAddr) -> bool {
    let ip: IpAddr = addr.ip();
    ip.is_loopback()
}
47
48pub fn execute(opts: ServeOpts) -> Result<()> {
50 let addr: std::net::SocketAddr = opts
51 .bind
52 .parse()
53 .map_err(|e| anyhow::anyhow!("invalid bind address '{}': {e}", opts.bind))?;
54
55 if !is_loopback(&addr) {
56 if !opts.insecure {
57 let err = error_json(
58 "serve_unsafe_bind",
59 &format!("refusing to bind to non-loopback address {addr} without --insecure"),
60 );
61 eprintln!("Error: non-loopback bind address {addr} requires --insecure flag");
62 println!("{}", serde_json::to_string(&err).unwrap());
63 std::process::exit(1);
64 }
65
66 let token = std::env::var("AGENT_EXEC_SERVE_TOKEN").ok();
67 if token.as_ref().is_none_or(|t| t.is_empty()) {
68 let err = error_json(
69 "serve_unsafe_bind",
70 &format!(
71 "refusing to bind to non-loopback address {addr} without AGENT_EXEC_SERVE_TOKEN"
72 ),
73 );
74 eprintln!(
75 "Error: non-loopback bind address {addr} requires AGENT_EXEC_SERVE_TOKEN to be set"
76 );
77 println!("{}", serde_json::to_string(&err).unwrap());
78 std::process::exit(1);
79 }
80 }
81
82 if let Some(ref origin) = opts.allow_origin
83 && origin == "*"
84 {
85 let err = error_json("invalid_config", "wildcard '*' origin is not allowed");
86 eprintln!("Error: --allow-origin '*' is not permitted");
87 println!("{}", serde_json::to_string(&err).unwrap());
88 std::process::exit(1);
89 }
90
91 let rt = tokio::runtime::Runtime::new()?;
92 rt.block_on(async_main(opts, addr))
93}
94
95async fn async_main(opts: ServeOpts, addr: std::net::SocketAddr) -> Result<()> {
96 let token = std::env::var("AGENT_EXEC_SERVE_TOKEN")
97 .ok()
98 .filter(|t| !t.is_empty());
99
100 let state = Arc::new(AppState {
101 root: opts.root,
102 token,
103 allow_origin: opts.allow_origin,
104 });
105
106 let mutating_routes = Router::new()
107 .route("/exec", post(exec_handler))
108 .route("/kill/{id}", post(kill_handler))
109 .layer(middleware::from_fn_with_state(
110 state.clone(),
111 auth_middleware,
112 ));
113
114 let readonly_routes = Router::new()
115 .route("/health", get(health_handler))
116 .route("/status/{id}", get(status_handler))
117 .route("/tail/{id}", get(tail_handler))
118 .route("/wait/{id}", get(wait_handler));
119
120 let mut router = Router::new()
121 .merge(mutating_routes)
122 .merge(readonly_routes)
123 .with_state(state.clone());
124
125 if let Some(ref origin) = state.allow_origin {
126 use tower_http::cors::CorsLayer;
127 let cors = CorsLayer::new()
128 .allow_origin(
129 origin
130 .parse::<axum::http::HeaderValue>()
131 .map_err(|e| anyhow::anyhow!("invalid origin '{}': {e}", origin))?,
132 )
133 .allow_methods([
134 axum::http::Method::GET,
135 axum::http::Method::POST,
136 axum::http::Method::OPTIONS,
137 ])
138 .allow_headers([
139 axum::http::header::AUTHORIZATION,
140 axum::http::header::CONTENT_TYPE,
141 ]);
142 router = router.layer(cors);
143 }
144
145 tracing::info!("serve listening on {addr}");
146 let listener = tokio::net::TcpListener::bind(addr).await?;
147 axum::serve(listener, router).await?;
148 Ok(())
149}
150
151async fn auth_middleware(
154 State(state): State<Arc<AppState>>,
155 request: Request,
156 next: Next,
157) -> AxumResponse {
158 if let Some(ref expected) = state.token {
159 let auth_header = request
160 .headers()
161 .get(axum::http::header::AUTHORIZATION)
162 .and_then(|v| v.to_str().ok());
163
164 let valid = match auth_header {
165 Some(h) if h.starts_with("Bearer ") => &h[7..] == expected.as_str(),
166 _ => false,
167 };
168
169 if !valid {
170 return err_resp(
171 StatusCode::UNAUTHORIZED,
172 "unauthorized",
173 "missing or invalid Bearer token",
174 );
175 }
176 }
177
178 next.run(request).await
179}
180
181fn error_json(code: &str, message: &str) -> serde_json::Value {
184 serde_json::json!({
185 "schema_version": SCHEMA_VERSION,
186 "ok": false,
187 "type": "error",
188 "error": {
189 "code": code,
190 "message": message,
191 "retryable": false
192 }
193 })
194}
195
196fn err_resp(status: StatusCode, code: &str, message: &str) -> AxumResponse {
197 (status, Json(error_json(code, message))).into_response()
198}
199
200fn map_err_to_response(e: anyhow::Error) -> AxumResponse {
201 if e.downcast_ref::<JobNotFound>().is_some() {
202 err_resp(StatusCode::NOT_FOUND, "job_not_found", &format!("{e:#}"))
203 } else if let Some(amb) = e.downcast_ref::<crate::jobstore::AmbiguousJobId>() {
204 let truncated = amb.candidates.len() > 20;
205 let candidates: Vec<&str> = amb.candidates.iter().take(20).map(|s| s.as_str()).collect();
206 let mut json = error_json("ambiguous_job_id", &format!("{e:#}"));
207 json["error"]["details"] = serde_json::json!({
208 "candidates": candidates,
209 "truncated": truncated,
210 });
211 (StatusCode::BAD_REQUEST, Json(json)).into_response()
212 } else if e
213 .downcast_ref::<crate::jobstore::InvalidJobState>()
214 .is_some()
215 {
216 err_resp(StatusCode::BAD_REQUEST, "invalid_state", &format!("{e:#}"))
217 } else {
218 err_resp(
219 StatusCode::INTERNAL_SERVER_ERROR,
220 "internal_error",
221 &format!("{e:#}"),
222 )
223 }
224}
225
226async fn health_handler() -> impl IntoResponse {
229 let resp = serde_json::json!({
230 "schema_version": SCHEMA_VERSION,
231 "ok": true,
232 "type": "health"
233 });
234 (StatusCode::OK, Json(resp))
235}
236
237#[derive(Deserialize)]
240#[serde(deny_unknown_fields)]
241struct ExecRequest {
242 command: Option<Vec<String>>,
243 cwd: Option<String>,
244 env: Option<HashMap<String, String>>,
245 timeout: Option<f64>,
246 wait: Option<bool>,
247 until: Option<u64>,
248 max_bytes: Option<u64>,
249}
250
251async fn exec_handler(State(state): State<Arc<AppState>>, request: Request) -> AxumResponse {
252 let body_bytes = match axum::body::to_bytes(request.into_body(), 1024 * 1024).await {
254 Ok(b) => b,
255 Err(_) => {
256 return err_resp(
257 StatusCode::BAD_REQUEST,
258 "invalid_request",
259 "failed to read request body",
260 );
261 }
262 };
263
264 if body_bytes.is_empty() {
265 return err_resp(
266 StatusCode::BAD_REQUEST,
267 "invalid_request",
268 "request body is required",
269 );
270 }
271
272 let req: ExecRequest = match serde_json::from_slice(&body_bytes) {
273 Ok(r) => r,
274 Err(e) => {
275 return err_resp(
276 StatusCode::BAD_REQUEST,
277 "invalid_request",
278 &format!("invalid JSON: {e}"),
279 );
280 }
281 };
282
283 let command = match req.command {
284 Some(c) if !c.is_empty() => c,
285 _ => {
286 return err_resp(
287 StatusCode::BAD_REQUEST,
288 "invalid_request",
289 "command field is required and must be non-empty",
290 );
291 }
292 };
293
294 let root_opt = state.root.clone();
295 let env_vars: Vec<String> = req
296 .env
297 .unwrap_or_default()
298 .into_iter()
299 .map(|(k, v)| format!("{k}={v}"))
300 .collect();
301 let cwd = req.cwd;
302 let timeout_ms = req.timeout.map(|s| (s * 1000.0) as u64).unwrap_or(0);
303 let wait = req.wait.unwrap_or(true);
304 let until = req.until.unwrap_or(10);
305 let max_bytes = req.max_bytes.unwrap_or(65536);
306
307 let result = tokio::task::spawn_blocking(move || {
308 run_exec_inner(ExecParams {
309 root: root_opt,
310 command,
311 cwd,
312 env_vars,
313 timeout_ms,
314 wait,
315 until,
316 max_bytes,
317 })
318 })
319 .await;
320
321 match result {
322 Ok(Ok(val)) => (StatusCode::OK, Json(val)).into_response(),
323 Ok(Err(e)) => map_err_to_response(e),
324 Err(e) => err_resp(
325 StatusCode::INTERNAL_SERVER_ERROR,
326 "internal_error",
327 &format!("task error: {e}"),
328 ),
329 }
330}
331
/// Validated, defaulted inputs handed from `exec_handler` to `run_exec_inner`.
struct ExecParams {
    /// Optional jobs-root override, resolved with `resolve_root`.
    root: Option<String>,
    /// Command and arguments to run.
    command: Vec<String>,
    /// Requested working directory, if any.
    cwd: Option<String>,
    /// Extra environment variables as `KEY=VALUE` strings.
    env_vars: Vec<String>,
    /// Timeout in milliseconds.
    timeout_ms: u64,
    /// Forwarded to `observe_inline_output`.
    wait: bool,
    /// Forwarded to `observe_inline_output`.
    until: u64,
    /// Forwarded to `observe_inline_output`.
    max_bytes: u64,
}
342
343fn run_exec_inner(p: ExecParams) -> Result<serde_json::Value> {
344 use crate::run::{
345 SpawnSupervisorParams, now_rfc3339_pub, observe_inline_output, pre_create_log_files,
346 resolve_effective_cwd, spawn_supervisor_process,
347 };
348
349 let elapsed_start = std::time::Instant::now();
350 let resolved_root = resolve_root(p.root.as_deref());
351 std::fs::create_dir_all(&resolved_root)
352 .map_err(|e| anyhow::anyhow!("create jobs root: {e}"))?;
353
354 let job_id = generate_job_id(&resolved_root)?;
355 let created_at = now_rfc3339_pub();
356 let effective_cwd = resolve_effective_cwd(p.cwd.as_deref());
357 let shell_wrapper = crate::config::default_shell_wrapper();
358
359 let env_keys: Vec<String> = p
360 .env_vars
361 .iter()
362 .map(|kv| kv.split('=').next().unwrap_or(kv).to_string())
363 .collect();
364
365 let meta = JobMeta {
366 job: JobMetaJob { id: job_id.clone() },
367 schema_version: SCHEMA_VERSION.to_string(),
368 command: p.command.clone(),
369 created_at,
370 root: resolved_root.display().to_string(),
371 env_keys,
372 env_vars: p.env_vars.clone(),
373 env_vars_runtime: vec![],
374 mask: vec![],
375 cwd: Some(effective_cwd),
376 notification: None,
377 inherit_env: true,
378 env_files: vec![],
379 timeout_ms: p.timeout_ms,
380 kill_after_ms: 0,
381 progress_every_ms: 0,
382 shell_wrapper: Some(shell_wrapper.clone()),
383 stdin_file: None,
384 tags: vec![],
385 };
386
387 let job_dir = JobDir::create(&resolved_root, &job_id, &meta)?;
388 pre_create_log_files(&job_dir)?;
389
390 spawn_supervisor_process(
391 &job_dir,
392 SpawnSupervisorParams {
393 job_id: job_id.clone(),
394 root: resolved_root.clone(),
395 full_log_path: job_dir.full_log_path().display().to_string(),
396 timeout_ms: p.timeout_ms,
397 kill_after_ms: 0,
398 cwd: p.cwd.clone(),
399 env_vars: p.env_vars.clone(),
400 env_files: vec![],
401 inherit_env: true,
402 stdin_file: None,
403 progress_every_ms: 0,
404 notify_command: None,
405 notify_file: None,
406 shell_wrapper: shell_wrapper.clone(),
407 command: p.command.clone(),
408 },
409 )?;
410
411 let stdout_log_path = job_dir.stdout_path().display().to_string();
412 let stderr_log_path = job_dir.stderr_path().display().to_string();
413
414 let observation = observe_inline_output(&job_dir, p.wait, p.until, false, p.max_bytes)?;
415
416 let elapsed_ms = elapsed_start.elapsed().as_millis() as u64;
417
418 let response = Response::new(
419 "run",
420 RunData {
421 job_id,
422 state: observation.state,
423 tags: vec![],
424 env_vars: vec![],
425 stdout_log_path,
426 stderr_log_path,
427 elapsed_ms,
428 waited_ms: observation.waited_ms,
429 stdout: observation.stdout,
430 stderr: observation.stderr,
431 stdout_range: observation.stdout_range,
432 stderr_range: observation.stderr_range,
433 stdout_total_bytes: observation.stdout_total_bytes,
434 stderr_total_bytes: observation.stderr_total_bytes,
435 encoding: observation.encoding,
436 exit_code: observation.exit_code,
437 finished_at: observation.finished_at,
438 signal: observation.signal,
439 duration_ms: observation.duration_ms,
440 },
441 );
442
443 Ok(serde_json::to_value(&response)?)
444}
445
446async fn status_handler(
449 State(state): State<Arc<AppState>>,
450 Path(id): Path<String>,
451) -> AxumResponse {
452 let root_opt = state.root.clone();
453 let result = tokio::task::spawn_blocking(move || {
454 let root = resolve_root(root_opt.as_deref());
455 let job_dir = JobDir::open(&root, &id)?;
456 let meta = job_dir.read_meta()?;
457 let st = job_dir.read_state()?;
458 let response = Response::new(
459 "status",
460 StatusData {
461 job_id: job_dir.job_id.clone(),
462 state: st.status().as_str().to_string(),
463 exit_code: st.exit_code(),
464 created_at: meta.created_at,
465 started_at: st.started_at().map(|s| s.to_string()),
466 finished_at: st.finished_at,
467 },
468 );
469 Ok::<_, anyhow::Error>(serde_json::to_value(&response)?)
470 })
471 .await;
472
473 match result {
474 Ok(Ok(val)) => (StatusCode::OK, Json(val)).into_response(),
475 Ok(Err(e)) => map_err_to_response(e),
476 Err(e) => err_resp(
477 StatusCode::INTERNAL_SERVER_ERROR,
478 "internal_error",
479 &format!("task error: {e}"),
480 ),
481 }
482}
483
484async fn tail_handler(State(state): State<Arc<AppState>>, Path(id): Path<String>) -> AxumResponse {
487 let root_opt = state.root.clone();
488 let result = tokio::task::spawn_blocking(move || {
489 let root = resolve_root(root_opt.as_deref());
490 let job_dir = JobDir::open(&root, &id)?;
491 let stdout_log_path = job_dir.stdout_path();
492 let stderr_log_path = job_dir.stderr_path();
493 let stdout = job_dir.read_tail_metrics("stdout.log", 50, 65536);
494 let stderr = job_dir.read_tail_metrics("stderr.log", 50, 65536);
495 let response = Response::new(
496 "tail",
497 TailData {
498 job_id: job_dir.job_id.clone(),
499 stdout: stdout.tail,
500 stderr: stderr.tail,
501 encoding: "utf-8-lossy".to_string(),
502 stdout_log_path: stdout_log_path.display().to_string(),
503 stderr_log_path: stderr_log_path.display().to_string(),
504 stdout_range: stdout.range,
505 stderr_range: stderr.range,
506 stdout_total_bytes: stdout.observed_bytes,
507 stderr_total_bytes: stderr.observed_bytes,
508 },
509 );
510 Ok::<_, anyhow::Error>(serde_json::to_value(&response)?)
511 })
512 .await;
513
514 match result {
515 Ok(Ok(val)) => (StatusCode::OK, Json(val)).into_response(),
516 Ok(Err(e)) => map_err_to_response(e),
517 Err(e) => err_resp(
518 StatusCode::INTERNAL_SERVER_ERROR,
519 "internal_error",
520 &format!("task error: {e}"),
521 ),
522 }
523}
524
525async fn wait_handler(State(state): State<Arc<AppState>>, Path(id): Path<String>) -> AxumResponse {
528 let root_opt = state.root.clone();
529 let result = tokio::task::spawn_blocking(move || {
530 let root = resolve_root(root_opt.as_deref());
531 let job_dir = JobDir::open(&root, &id)?;
532 let poll = std::time::Duration::from_millis(200);
533 loop {
534 let st = job_dir.read_state()?;
535 if !st.status().is_non_terminal() {
536 let response = Response::new("wait", crate::wait::build_wait_data(&job_dir, &st));
537 return Ok::<_, anyhow::Error>(serde_json::to_value(&response)?);
538 }
539 std::thread::sleep(poll);
540 }
541 })
542 .await;
543
544 match result {
545 Ok(Ok(val)) => (StatusCode::OK, Json(val)).into_response(),
546 Ok(Err(e)) => map_err_to_response(e),
547 Err(e) => err_resp(
548 StatusCode::INTERNAL_SERVER_ERROR,
549 "internal_error",
550 &format!("task error: {e}"),
551 ),
552 }
553}
554
555#[derive(Deserialize)]
558struct KillQuery {
559 #[serde(default)]
560 no_wait: Option<bool>,
561}
562
563async fn kill_handler(
564 State(state): State<Arc<AppState>>,
565 Path(id): Path<String>,
566 Query(query): Query<KillQuery>,
567) -> AxumResponse {
568 let root_opt = state.root.clone();
569 let no_wait = query.no_wait.unwrap_or(false);
570 let result = tokio::task::spawn_blocking(move || {
571 let data = crate::kill::execute_inner(crate::kill::KillOpts {
572 job_id: &id,
573 root: root_opt.as_deref(),
574 signal: "TERM",
575 no_wait,
576 })?;
577 let response = Response::new("kill", data);
578 Ok::<_, anyhow::Error>(serde_json::to_value(&response)?)
579 })
580 .await;
581
582 match result {
583 Ok(Ok(val)) => (StatusCode::OK, Json(val)).into_response(),
584 Ok(Err(e)) => map_err_to_response(e),
585 Err(e) => err_resp(
586 StatusCode::INTERNAL_SERVER_ERROR,
587 "internal_error",
588 &format!("task error: {e}"),
589 ),
590 }
591}
592
#[cfg(test)]
mod tests {
    use super::*;
    use tower::ServiceExt as _;

    /// Parses a socket-address literal; malformed test input is a bug in the test.
    fn socket(literal: &str) -> std::net::SocketAddr {
        literal.parse().unwrap()
    }

    // ---- is_loopback ----

    #[test]
    fn test_is_loopback_ipv4_localhost() {
        assert!(is_loopback(&socket("127.0.0.1:8080")));
    }

    #[test]
    fn test_is_loopback_ipv4_127_range() {
        assert!(is_loopback(&socket("127.0.0.2:8080")));
    }

    #[test]
    fn test_is_loopback_ipv6() {
        assert!(is_loopback(&socket("[::1]:8080")));
    }

    #[test]
    fn test_not_loopback_wildcard() {
        assert!(!is_loopback(&socket("0.0.0.0:8080")));
    }

    #[test]
    fn test_not_loopback_external() {
        assert!(!is_loopback(&socket("192.168.1.1:8080")));
    }

    #[test]
    fn test_not_loopback_ipv6_all() {
        assert!(!is_loopback(&socket("[::]:8080")));
    }

    // ---- error_json ----

    #[test]
    fn test_error_json_structure() {
        let envelope = error_json("test_code", "test message");
        assert_eq!(envelope["ok"], false);
        assert_eq!(envelope["error"]["code"], "test_code");
        assert_eq!(envelope["error"]["message"], "test message");
        assert_eq!(envelope["type"], "error");
    }

    // ---- auth_middleware ----

    /// Minimal router with one `/test` route guarded by `auth_middleware`.
    fn test_app(token: Option<&str>) -> Router {
        let state = Arc::new(AppState {
            root: None,
            token: token.map(str::to_string),
            allow_origin: None,
        });
        let guard = middleware::from_fn_with_state(Arc::clone(&state), auth_middleware);
        Router::new()
            .route("/test", get(|| async { "ok" }))
            .layer(guard)
            .with_state(state)
    }

    /// Builds a body-less request for `uri`, optionally with an `Authorization` header.
    fn req(uri: &str, auth: Option<&str>) -> axum::http::Request<axum::body::Body> {
        let builder = axum::http::Request::builder().uri(uri);
        let builder = match auth {
            Some(value) => builder.header("Authorization", value),
            None => builder,
        };
        builder.body(axum::body::Body::empty()).unwrap()
    }

    /// Sends one request to `/test` and returns the response status.
    async fn auth_status(token: Option<&str>, auth: Option<&str>) -> StatusCode {
        test_app(token)
            .oneshot(req("/test", auth))
            .await
            .unwrap()
            .status()
    }

    #[tokio::test]
    async fn test_auth_middleware_no_token_configured() {
        assert_eq!(auth_status(None, None).await, StatusCode::OK);
    }

    #[tokio::test]
    async fn test_auth_middleware_valid_token() {
        let status = auth_status(Some("secret123"), Some("Bearer secret123")).await;
        assert_eq!(status, StatusCode::OK);
    }

    #[tokio::test]
    async fn test_auth_middleware_missing_token() {
        let status = auth_status(Some("secret123"), None).await;
        assert_eq!(status, StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn test_auth_middleware_wrong_token() {
        let status = auth_status(Some("secret123"), Some("Bearer wrong")).await;
        assert_eq!(status, StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn test_auth_middleware_non_bearer_scheme() {
        let status = auth_status(Some("secret123"), Some("Basic secret123")).await;
        assert_eq!(status, StatusCode::UNAUTHORIZED);
    }

    // ---- ExecRequest ----

    #[test]
    fn exec_request_deserializes_new_fields() {
        let body =
            r#"{"command":["echo","hi"],"wait":false,"until":5,"max_bytes":1024,"timeout":30.5}"#;
        let parsed: ExecRequest = serde_json::from_str(body).unwrap();
        assert_eq!(parsed.wait, Some(false));
        assert_eq!(parsed.until, Some(5));
        assert_eq!(parsed.max_bytes, Some(1024));
        assert!((parsed.timeout.unwrap() - 30.5).abs() < f64::EPSILON);
    }

    #[test]
    fn exec_request_defaults_when_omitted() {
        let body = r#"{"command":["echo","hi"]}"#;
        let parsed: ExecRequest = serde_json::from_str(body).unwrap();
        assert_eq!(parsed.wait, None);
        assert_eq!(parsed.until, None);
        assert_eq!(parsed.max_bytes, None);
        assert_eq!(parsed.timeout, None);
    }

    #[test]
    fn exec_request_rejects_timeout_ms() {
        let body = r#"{"command":["echo","hi"],"timeout_ms":1000}"#;
        let outcome = serde_json::from_str::<ExecRequest>(body);
        assert!(
            outcome.is_err(),
            "timeout_ms should be rejected as unknown field"
        );
    }
}