1use anyhow::Result;
9use axum::{
10 Json, Router,
11 extract::{Path, Query, Request, State},
12 http::StatusCode,
13 middleware::{self, Next},
14 response::{IntoResponse, Response as AxumResponse},
15 routing::{get, post},
16};
17use serde::Deserialize;
18use std::collections::HashMap;
19use std::net::IpAddr;
20use std::sync::Arc;
21
22use crate::jobstore::{JobDir, JobNotFound, generate_job_id, resolve_root};
23use crate::schema::{JobMeta, JobMetaJob, Response, RunData, SCHEMA_VERSION, StatusData, TailData};
24
/// Command-line options for the `serve` subcommand.
pub struct ServeOpts {
    /// Socket address to listen on, parsed as a `SocketAddr` (e.g. `127.0.0.1:8080`).
    pub bind: String,
    /// Jobs root override; `None` lets `resolve_root` pick the default.
    pub root: Option<String>,
    /// Permit binding to a non-loopback address (a non-empty
    /// `AGENT_EXEC_SERVE_TOKEN` is required in that case as well).
    pub insecure: bool,
    /// Single origin allowed by the CORS layer; the wildcard `*` is rejected.
    pub allow_origin: Option<String>,
}
32
/// State shared by every request handler (wrapped in an `Arc`).
#[derive(Clone)]
struct AppState {
    /// Bearer token required by the authenticated routes; `None` disables auth.
    token: Option<String>,
    /// Jobs root override forwarded to `resolve_root`.
    root: Option<String>,
    /// Origin allowed by the CORS layer, if any.
    allow_origin: Option<String>,
}
39
/// Returns `true` when `addr` is a loopback address (`127.0.0.0/8` or `::1`).
///
/// IPv4-mapped IPv6 addresses such as `::ffff:127.0.0.1` are canonicalised to
/// their IPv4 form first. `Ipv6Addr::is_loopback` alone only recognises `::1`,
/// so without this a mapped loopback bind would be treated as externally reachable.
pub fn is_loopback(addr: &std::net::SocketAddr) -> bool {
    match addr.ip().to_canonical() {
        IpAddr::V4(v4) => v4.is_loopback(),
        IpAddr::V6(v6) => v6.is_loopback(),
    }
}
47
48pub fn execute(opts: ServeOpts) -> Result<()> {
50 let addr: std::net::SocketAddr = opts
51 .bind
52 .parse()
53 .map_err(|e| anyhow::anyhow!("invalid bind address '{}': {e}", opts.bind))?;
54
55 if !is_loopback(&addr) {
56 if !opts.insecure {
57 let err = error_json(
58 "serve_unsafe_bind",
59 &format!("refusing to bind to non-loopback address {addr} without --insecure"),
60 );
61 eprintln!("Error: non-loopback bind address {addr} requires --insecure flag");
62 println!("{}", serde_json::to_string(&err).unwrap());
63 std::process::exit(1);
64 }
65
66 let token = std::env::var("AGENT_EXEC_SERVE_TOKEN").ok();
67 if token.as_ref().is_none_or(|t| t.is_empty()) {
68 let err = error_json(
69 "serve_unsafe_bind",
70 &format!(
71 "refusing to bind to non-loopback address {addr} without AGENT_EXEC_SERVE_TOKEN"
72 ),
73 );
74 eprintln!(
75 "Error: non-loopback bind address {addr} requires AGENT_EXEC_SERVE_TOKEN to be set"
76 );
77 println!("{}", serde_json::to_string(&err).unwrap());
78 std::process::exit(1);
79 }
80 }
81
82 if let Some(ref origin) = opts.allow_origin
83 && origin == "*"
84 {
85 let err = error_json("invalid_config", "wildcard '*' origin is not allowed");
86 eprintln!("Error: --allow-origin '*' is not permitted");
87 println!("{}", serde_json::to_string(&err).unwrap());
88 std::process::exit(1);
89 }
90
91 let rt = tokio::runtime::Runtime::new()?;
92 rt.block_on(async_main(opts, addr))
93}
94
95async fn async_main(opts: ServeOpts, addr: std::net::SocketAddr) -> Result<()> {
96 let token = std::env::var("AGENT_EXEC_SERVE_TOKEN")
97 .ok()
98 .filter(|t| !t.is_empty());
99
100 let state = Arc::new(AppState {
101 root: opts.root,
102 token,
103 allow_origin: opts.allow_origin,
104 });
105
106 let mutating_routes = Router::new()
107 .route("/exec", post(exec_handler))
108 .route("/kill/{id}", post(kill_handler))
109 .layer(middleware::from_fn_with_state(
110 state.clone(),
111 auth_middleware,
112 ));
113
114 let readonly_routes = Router::new()
115 .route("/health", get(health_handler))
116 .route("/status/{id}", get(status_handler))
117 .route("/tail/{id}", get(tail_handler))
118 .route("/wait/{id}", get(wait_handler));
119
120 let mut router = Router::new()
121 .merge(mutating_routes)
122 .merge(readonly_routes)
123 .with_state(state.clone());
124
125 if let Some(ref origin) = state.allow_origin {
126 use tower_http::cors::CorsLayer;
127 let cors = CorsLayer::new()
128 .allow_origin(
129 origin
130 .parse::<axum::http::HeaderValue>()
131 .map_err(|e| anyhow::anyhow!("invalid origin '{}': {e}", origin))?,
132 )
133 .allow_methods([
134 axum::http::Method::GET,
135 axum::http::Method::POST,
136 axum::http::Method::OPTIONS,
137 ])
138 .allow_headers([
139 axum::http::header::AUTHORIZATION,
140 axum::http::header::CONTENT_TYPE,
141 ]);
142 router = router.layer(cors);
143 }
144
145 tracing::info!("serve listening on {addr}");
146 let listener = tokio::net::TcpListener::bind(addr).await?;
147 axum::serve(listener, router).await?;
148 Ok(())
149}
150
151async fn auth_middleware(
154 State(state): State<Arc<AppState>>,
155 request: Request,
156 next: Next,
157) -> AxumResponse {
158 if let Some(ref expected) = state.token {
159 let auth_header = request
160 .headers()
161 .get(axum::http::header::AUTHORIZATION)
162 .and_then(|v| v.to_str().ok());
163
164 let valid = match auth_header {
165 Some(h) if h.starts_with("Bearer ") => &h[7..] == expected.as_str(),
166 _ => false,
167 };
168
169 if !valid {
170 return err_resp(
171 StatusCode::UNAUTHORIZED,
172 "unauthorized",
173 "missing or invalid Bearer token",
174 );
175 }
176 }
177
178 next.run(request).await
179}
180
181fn error_json(code: &str, message: &str) -> serde_json::Value {
184 serde_json::json!({
185 "schema_version": SCHEMA_VERSION,
186 "ok": false,
187 "type": "error",
188 "error": {
189 "code": code,
190 "message": message,
191 "retryable": false
192 }
193 })
194}
195
196fn err_resp(status: StatusCode, code: &str, message: &str) -> AxumResponse {
197 (status, Json(error_json(code, message))).into_response()
198}
199
200fn map_err_to_response(e: anyhow::Error) -> AxumResponse {
201 if e.downcast_ref::<JobNotFound>().is_some() {
202 err_resp(StatusCode::NOT_FOUND, "job_not_found", &format!("{e:#}"))
203 } else if let Some(amb) = e.downcast_ref::<crate::jobstore::AmbiguousJobId>() {
204 let truncated = amb.candidates.len() > 20;
205 let candidates: Vec<&str> = amb.candidates.iter().take(20).map(|s| s.as_str()).collect();
206 let mut json = error_json("ambiguous_job_id", &format!("{e:#}"));
207 json["error"]["details"] = serde_json::json!({
208 "candidates": candidates,
209 "truncated": truncated,
210 });
211 (StatusCode::BAD_REQUEST, Json(json)).into_response()
212 } else if e
213 .downcast_ref::<crate::jobstore::InvalidJobState>()
214 .is_some()
215 {
216 err_resp(StatusCode::BAD_REQUEST, "invalid_state", &format!("{e:#}"))
217 } else {
218 err_resp(
219 StatusCode::INTERNAL_SERVER_ERROR,
220 "internal_error",
221 &format!("{e:#}"),
222 )
223 }
224}
225
226async fn health_handler() -> impl IntoResponse {
229 let resp = serde_json::json!({
230 "schema_version": SCHEMA_VERSION,
231 "ok": true,
232 "type": "health"
233 });
234 (StatusCode::OK, Json(resp))
235}
236
237#[derive(Deserialize)]
240#[serde(deny_unknown_fields)]
241struct ExecRequest {
242 command: Option<Vec<String>>,
243 cwd: Option<String>,
244 env: Option<HashMap<String, String>>,
245 timeout: Option<f64>,
246 wait: Option<bool>,
247 until: Option<u64>,
248 max_bytes: Option<u64>,
249}
250
251async fn exec_handler(State(state): State<Arc<AppState>>, request: Request) -> AxumResponse {
252 let body_bytes = match axum::body::to_bytes(request.into_body(), 1024 * 1024).await {
254 Ok(b) => b,
255 Err(_) => {
256 return err_resp(
257 StatusCode::BAD_REQUEST,
258 "invalid_request",
259 "failed to read request body",
260 );
261 }
262 };
263
264 if body_bytes.is_empty() {
265 return err_resp(
266 StatusCode::BAD_REQUEST,
267 "invalid_request",
268 "request body is required",
269 );
270 }
271
272 let req: ExecRequest = match serde_json::from_slice(&body_bytes) {
273 Ok(r) => r,
274 Err(e) => {
275 return err_resp(
276 StatusCode::BAD_REQUEST,
277 "invalid_request",
278 &format!("invalid JSON: {e}"),
279 );
280 }
281 };
282
283 let command = match req.command {
284 Some(c) if !c.is_empty() => c,
285 _ => {
286 return err_resp(
287 StatusCode::BAD_REQUEST,
288 "invalid_request",
289 "command field is required and must be non-empty",
290 );
291 }
292 };
293
294 let root_opt = state.root.clone();
295 let env_vars: Vec<String> = req
296 .env
297 .unwrap_or_default()
298 .into_iter()
299 .map(|(k, v)| format!("{k}={v}"))
300 .collect();
301 let cwd = req.cwd;
302 let timeout_ms = req.timeout.map(|s| (s * 1000.0) as u64).unwrap_or(0);
303 let wait = req.wait.unwrap_or(true);
304 let until = req.until.unwrap_or(10);
305 let max_bytes = req.max_bytes.unwrap_or(65536);
306
307 let result = tokio::task::spawn_blocking(move || {
308 run_exec_inner(ExecParams {
309 root: root_opt,
310 command,
311 cwd,
312 env_vars,
313 timeout_ms,
314 wait,
315 until,
316 max_bytes,
317 })
318 })
319 .await;
320
321 match result {
322 Ok(Ok(val)) => (StatusCode::OK, Json(val)).into_response(),
323 Ok(Err(e)) => map_err_to_response(e),
324 Err(e) => err_resp(
325 StatusCode::INTERNAL_SERVER_ERROR,
326 "internal_error",
327 &format!("task error: {e}"),
328 ),
329 }
330}
331
/// Validated inputs for `run_exec_inner`, assembled by `exec_handler`.
struct ExecParams {
    /// Jobs root override forwarded to `resolve_root`.
    root: Option<String>,
    /// Command argv (non-empty).
    command: Vec<String>,
    /// Requested working directory, if any.
    cwd: Option<String>,
    /// Extra environment as `KEY=VALUE` strings.
    env_vars: Vec<String>,
    /// Timeout in milliseconds.
    timeout_ms: u64,
    /// Forwarded to `observe_inline_output` (`wait`, `until`, `max_bytes`).
    wait: bool,
    until: u64,
    max_bytes: u64,
}
342
343fn run_exec_inner(p: ExecParams) -> Result<serde_json::Value> {
344 use crate::run::{
345 SpawnSupervisorParams, now_rfc3339_pub, observe_inline_output, pre_create_log_files,
346 resolve_effective_cwd, spawn_supervisor_process,
347 };
348
349 let elapsed_start = std::time::Instant::now();
350 let resolved_root = resolve_root(p.root.as_deref());
351 std::fs::create_dir_all(&resolved_root)
352 .map_err(|e| anyhow::anyhow!("create jobs root: {e}"))?;
353
354 let job_id = generate_job_id(&resolved_root)?;
355 let created_at = now_rfc3339_pub();
356 let effective_cwd = resolve_effective_cwd(p.cwd.as_deref());
357 let shell_wrapper = crate::config::default_shell_wrapper();
358
359 let env_keys: Vec<String> = p
360 .env_vars
361 .iter()
362 .map(|kv| kv.split('=').next().unwrap_or(kv).to_string())
363 .collect();
364
365 let meta = JobMeta {
366 job: JobMetaJob { id: job_id.clone() },
367 schema_version: SCHEMA_VERSION.to_string(),
368 command: p.command.clone(),
369 created_at,
370 root: resolved_root.display().to_string(),
371 env_keys,
372 env_vars: p.env_vars.clone(),
373 env_vars_runtime: vec![],
374 mask: vec![],
375 cwd: Some(effective_cwd),
376 notification: None,
377 inherit_env: true,
378 env_files: vec![],
379 timeout_ms: p.timeout_ms,
380 kill_after_ms: 0,
381 progress_every_ms: 0,
382 shell_wrapper: Some(shell_wrapper.clone()),
383 stdin_file: None,
384 tags: vec![],
385 };
386
387 let job_dir = JobDir::create(&resolved_root, &job_id, &meta)?;
388 pre_create_log_files(&job_dir)?;
389
390 spawn_supervisor_process(
391 &job_dir,
392 SpawnSupervisorParams {
393 job_id: job_id.clone(),
394 root: resolved_root.clone(),
395 full_log_path: job_dir.full_log_path().display().to_string(),
396 timeout_ms: p.timeout_ms,
397 kill_after_ms: 0,
398 cwd: p.cwd.clone(),
399 env_vars: p.env_vars.clone(),
400 env_files: vec![],
401 inherit_env: true,
402 stdin_file: None,
403 progress_every_ms: 0,
404 notify_command: None,
405 notify_file: None,
406 shell_wrapper: shell_wrapper.clone(),
407 command: p.command.clone(),
408 },
409 )?;
410
411 let stdout_log_path = job_dir.stdout_path().display().to_string();
412 let stderr_log_path = job_dir.stderr_path().display().to_string();
413
414 let observation = observe_inline_output(&job_dir, p.wait, p.until, false, p.max_bytes)?;
415
416 let elapsed_ms = elapsed_start.elapsed().as_millis() as u64;
417
418 let response = Response::new(
419 "run",
420 RunData {
421 job_id,
422 state: observation.state,
423 tags: vec![],
424 env_vars: vec![],
425 stdout_log_path,
426 stderr_log_path,
427 elapsed_ms,
428 waited_ms: observation.waited_ms,
429 stdout: observation.stdout,
430 stderr: observation.stderr,
431 stdout_range: observation.stdout_range,
432 stderr_range: observation.stderr_range,
433 stdout_total_bytes: observation.stdout_total_bytes,
434 stderr_total_bytes: observation.stderr_total_bytes,
435 encoding: observation.encoding,
436 exit_code: observation.exit_code,
437 finished_at: observation.finished_at,
438 signal: observation.signal,
439 duration_ms: observation.duration_ms,
440 compression: None,
441 },
442 );
443
444 Ok(serde_json::to_value(&response)?)
445}
446
447async fn status_handler(
450 State(state): State<Arc<AppState>>,
451 Path(id): Path<String>,
452) -> AxumResponse {
453 let root_opt = state.root.clone();
454 let result = tokio::task::spawn_blocking(move || {
455 let root = resolve_root(root_opt.as_deref());
456 let job_dir = JobDir::open(&root, &id)?;
457 let meta = job_dir.read_meta()?;
458 let st = job_dir.read_state()?;
459 let response = Response::new(
460 "status",
461 StatusData {
462 job_id: job_dir.job_id.clone(),
463 state: st.status().as_str().to_string(),
464 exit_code: st.exit_code(),
465 created_at: meta.created_at,
466 started_at: st.started_at().map(|s| s.to_string()),
467 finished_at: st.finished_at,
468 },
469 );
470 Ok::<_, anyhow::Error>(serde_json::to_value(&response)?)
471 })
472 .await;
473
474 match result {
475 Ok(Ok(val)) => (StatusCode::OK, Json(val)).into_response(),
476 Ok(Err(e)) => map_err_to_response(e),
477 Err(e) => err_resp(
478 StatusCode::INTERNAL_SERVER_ERROR,
479 "internal_error",
480 &format!("task error: {e}"),
481 ),
482 }
483}
484
485async fn tail_handler(State(state): State<Arc<AppState>>, Path(id): Path<String>) -> AxumResponse {
488 let root_opt = state.root.clone();
489 let result = tokio::task::spawn_blocking(move || {
490 let root = resolve_root(root_opt.as_deref());
491 let job_dir = JobDir::open(&root, &id)?;
492 let stdout_log_path = job_dir.stdout_path();
493 let stderr_log_path = job_dir.stderr_path();
494 let stdout = job_dir.read_tail_metrics("stdout.log", 50, 65536);
495 let stderr = job_dir.read_tail_metrics("stderr.log", 50, 65536);
496 let response = Response::new(
497 "tail",
498 TailData {
499 job_id: job_dir.job_id.clone(),
500 stdout: stdout.tail,
501 stderr: stderr.tail,
502 encoding: "utf-8-lossy".to_string(),
503 stdout_log_path: stdout_log_path.display().to_string(),
504 stderr_log_path: stderr_log_path.display().to_string(),
505 stdout_range: stdout.range,
506 stderr_range: stderr.range,
507 stdout_total_bytes: stdout.observed_bytes,
508 stderr_total_bytes: stderr.observed_bytes,
509 compression: None,
510 },
511 );
512 Ok::<_, anyhow::Error>(serde_json::to_value(&response)?)
513 })
514 .await;
515
516 match result {
517 Ok(Ok(val)) => (StatusCode::OK, Json(val)).into_response(),
518 Ok(Err(e)) => map_err_to_response(e),
519 Err(e) => err_resp(
520 StatusCode::INTERNAL_SERVER_ERROR,
521 "internal_error",
522 &format!("task error: {e}"),
523 ),
524 }
525}
526
527async fn wait_handler(State(state): State<Arc<AppState>>, Path(id): Path<String>) -> AxumResponse {
530 let root_opt = state.root.clone();
531 let result = tokio::task::spawn_blocking(move || {
532 let root = resolve_root(root_opt.as_deref());
533 let job_dir = JobDir::open(&root, &id)?;
534 let poll = std::time::Duration::from_millis(200);
535 loop {
536 let st = job_dir.read_state()?;
537 if !st.status().is_non_terminal() {
538 let response = Response::new("wait", crate::wait::build_wait_data(&job_dir, &st));
539 return Ok::<_, anyhow::Error>(serde_json::to_value(&response)?);
540 }
541 std::thread::sleep(poll);
542 }
543 })
544 .await;
545
546 match result {
547 Ok(Ok(val)) => (StatusCode::OK, Json(val)).into_response(),
548 Ok(Err(e)) => map_err_to_response(e),
549 Err(e) => err_resp(
550 StatusCode::INTERNAL_SERVER_ERROR,
551 "internal_error",
552 &format!("task error: {e}"),
553 ),
554 }
555}
556
557#[derive(Deserialize)]
560struct KillQuery {
561 #[serde(default)]
562 no_wait: Option<bool>,
563}
564
565async fn kill_handler(
566 State(state): State<Arc<AppState>>,
567 Path(id): Path<String>,
568 Query(query): Query<KillQuery>,
569) -> AxumResponse {
570 let root_opt = state.root.clone();
571 let no_wait = query.no_wait.unwrap_or(false);
572 let result = tokio::task::spawn_blocking(move || {
573 let data = crate::kill::execute_inner(crate::kill::KillOpts {
574 job_id: &id,
575 root: root_opt.as_deref(),
576 signal: "TERM",
577 no_wait,
578 })?;
579 let response = Response::new("kill", data);
580 Ok::<_, anyhow::Error>(serde_json::to_value(&response)?)
581 })
582 .await;
583
584 match result {
585 Ok(Ok(val)) => (StatusCode::OK, Json(val)).into_response(),
586 Ok(Err(e)) => map_err_to_response(e),
587 Err(e) => err_resp(
588 StatusCode::INTERNAL_SERVER_ERROR,
589 "internal_error",
590 &format!("task error: {e}"),
591 ),
592 }
593}
594
#[cfg(test)]
mod tests {
    use super::*;
    use tower::ServiceExt as _;

    fn sock(text: &str) -> std::net::SocketAddr {
        text.parse().expect("test socket address must parse")
    }

    #[test]
    fn loopback_addresses_are_detected() {
        for text in ["127.0.0.1:8080", "127.0.0.2:8080", "[::1]:8080"] {
            assert!(is_loopback(&sock(text)), "{text} should count as loopback");
        }
    }

    #[test]
    fn non_loopback_addresses_are_detected() {
        for text in ["0.0.0.0:8080", "192.168.1.1:8080", "[::]:8080"] {
            assert!(!is_loopback(&sock(text)), "{text} should not count as loopback");
        }
    }

    #[test]
    fn error_json_has_error_envelope_shape() {
        let envelope = error_json("test_code", "test message");
        assert_eq!(envelope["type"], "error");
        assert_eq!(envelope["ok"], false);
        let detail = &envelope["error"];
        assert_eq!(detail["code"], "test_code");
        assert_eq!(detail["message"], "test message");
    }

    /// A router with one `/test` route behind `auth_middleware`.
    fn guarded_router(token: Option<&str>) -> Router {
        let state = Arc::new(AppState {
            root: None,
            token: token.map(str::to_owned),
            allow_origin: None,
        });
        Router::new()
            .route("/test", get(|| async { "ok" }))
            .layer(middleware::from_fn_with_state(
                Arc::clone(&state),
                auth_middleware,
            ))
            .with_state(state)
    }

    /// Sends `GET /test` with an optional Authorization header and returns the status.
    async fn status_of(token: Option<&str>, authorization: Option<&str>) -> StatusCode {
        let builder = axum::http::Request::builder().uri("/test");
        let builder = match authorization {
            Some(value) => builder.header("Authorization", value),
            None => builder,
        };
        let request = builder.body(axum::body::Body::empty()).unwrap();
        guarded_router(token)
            .oneshot(request)
            .await
            .unwrap()
            .status()
    }

    #[tokio::test]
    async fn auth_passes_when_no_token_configured() {
        assert_eq!(status_of(None, None).await, StatusCode::OK);
    }

    #[tokio::test]
    async fn auth_accepts_matching_bearer_token() {
        assert_eq!(
            status_of(Some("secret123"), Some("Bearer secret123")).await,
            StatusCode::OK
        );
    }

    #[tokio::test]
    async fn auth_rejects_missing_header() {
        assert_eq!(
            status_of(Some("secret123"), None).await,
            StatusCode::UNAUTHORIZED
        );
    }

    #[tokio::test]
    async fn auth_rejects_wrong_token() {
        assert_eq!(
            status_of(Some("secret123"), Some("Bearer wrong")).await,
            StatusCode::UNAUTHORIZED
        );
    }

    #[tokio::test]
    async fn auth_rejects_non_bearer_scheme() {
        assert_eq!(
            status_of(Some("secret123"), Some("Basic secret123")).await,
            StatusCode::UNAUTHORIZED
        );
    }

    #[test]
    fn exec_request_reads_optional_fields() {
        let body =
            r#"{"command":["echo","hi"],"wait":false,"until":5,"max_bytes":1024,"timeout":30.5}"#;
        let parsed: ExecRequest = serde_json::from_str(body).unwrap();
        assert_eq!(parsed.wait, Some(false));
        assert_eq!(parsed.until, Some(5));
        assert_eq!(parsed.max_bytes, Some(1024));
        assert!((parsed.timeout.unwrap() - 30.5).abs() < f64::EPSILON);
    }

    #[test]
    fn exec_request_leaves_omitted_fields_unset() {
        let parsed: ExecRequest = serde_json::from_str(r#"{"command":["echo","hi"]}"#).unwrap();
        assert_eq!(parsed.wait, None);
        assert_eq!(parsed.until, None);
        assert_eq!(parsed.max_bytes, None);
        assert_eq!(parsed.timeout, None);
    }

    #[test]
    fn exec_request_rejects_unknown_timeout_ms() {
        let outcome =
            serde_json::from_str::<ExecRequest>(r#"{"command":["echo","hi"],"timeout_ms":1000}"#);
        assert!(
            outcome.is_err(),
            "timeout_ms should be rejected as unknown field"
        );
    }
}