1use super::{ExecutionError, StageExecutor};
39use noether_core::stage::StageId;
40use serde_json::Value;
41use sha2::{Digest, Sha256};
42use std::collections::HashMap;
43use std::io::Write as IoWrite;
44use std::path::{Path, PathBuf};
45use std::process::{Command, Stdio};
46use std::sync::mpsc;
47use std::time::Duration;
48
/// Resource limits that [`NixExecutor`] applies to every stage process.
#[derive(Debug, Clone)]
pub struct NixConfig {
    /// Wall-clock limit for one stage run, in seconds (default 30). A stage that
    /// runs longer is killed and reported as `ExecutionError::TimedOut`.
    pub timeout_secs: u64,
    /// Maximum stage stdout size, in bytes (default 10 MiB). Output beyond this
    /// size is rejected as a stage failure instead of being parsed.
    pub max_output_bytes: usize,
    /// Maximum number of stderr bytes kept for error messages (default 64 KiB);
    /// anything past that prefix is dropped.
    pub max_stderr_bytes: usize,
}
66
67impl Default for NixConfig {
68 fn default() -> Self {
69 Self {
70 timeout_secs: 30,
71 max_output_bytes: 10 * 1024 * 1024,
72 max_stderr_bytes: 64 * 1024,
73 }
74 }
75}
76
/// Source of one stage implementation as registered with [`NixExecutor`].
#[derive(Clone)]
struct StageImpl {
    /// User-supplied source; wrapped with runtime I/O boilerplate before it runs.
    code: String,
    /// Language tag as stored (e.g. `"python"`, `"javascript"`/`"js"`,
    /// `"bash"`/`"sh"`); selects the wrapper and the runtime.
    language: String,
}
85
/// [`StageExecutor`] that runs Python, JavaScript and Bash stage implementations
/// through Nix-provided runtimes (`nix run` / `nix shell`), or directly with the
/// interpreter of a cached pip virtualenv when the code declares `# requires:`.
pub struct NixExecutor {
    /// Path to the `nix` binary (see `find_nix`).
    nix_bin: PathBuf,
    /// Directory holding wrapped stage scripts (named by code hash) and pip
    /// virtualenvs; `$HOME/.noether/impl_cache` when built from a store.
    cache_dir: PathBuf,
    /// Timeout and output limits applied to each run.
    config: NixConfig,
    /// Registered implementations keyed by the stage id string (`StageId.0`).
    implementations: HashMap<String, StageImpl>,
}
105
106impl NixExecutor {
107 pub fn find_nix() -> Option<PathBuf> {
110 let determinate = PathBuf::from("/nix/var/nix/profiles/default/bin/nix");
112 if determinate.exists() {
113 return Some(determinate);
114 }
115 if let Ok(output) = Command::new("which").arg("nix").output() {
117 let p = std::str::from_utf8(&output.stdout)
118 .unwrap_or("")
119 .trim()
120 .to_string();
121 if !p.is_empty() {
122 return Some(PathBuf::from(p));
123 }
124 }
125 None
126 }
127
128 pub fn from_store(store: &dyn noether_store::StageStore) -> Option<Self> {
133 Self::from_store_with_config(store, NixConfig::default())
134 }
135
136 pub fn from_store_with_config(
138 store: &dyn noether_store::StageStore,
139 config: NixConfig,
140 ) -> Option<Self> {
141 let nix_bin = Self::find_nix()?;
142
143 let home = std::env::var("HOME").unwrap_or_else(|_| "/tmp".into());
144 let cache_dir = PathBuf::from(home).join(".noether").join("impl_cache");
145 let _ = std::fs::create_dir_all(&cache_dir);
146
147 let mut implementations = HashMap::new();
148 for stage in store.list(None) {
149 if let (Some(code), Some(lang)) =
150 (&stage.implementation_code, &stage.implementation_language)
151 {
152 implementations.insert(
153 stage.id.0.clone(),
154 StageImpl {
155 code: code.clone(),
156 language: lang.clone(),
157 },
158 );
159 }
160 }
161
162 Some(Self {
163 nix_bin,
164 cache_dir,
165 config,
166 implementations,
167 })
168 }
169
170 pub fn register(&mut self, stage_id: &StageId, code: &str, language: &str) {
172 self.implementations.insert(
173 stage_id.0.clone(),
174 StageImpl {
175 code: code.into(),
176 language: language.into(),
177 },
178 );
179 }
180
181 pub fn has_implementation(&self, stage_id: &StageId) -> bool {
183 self.implementations.contains_key(&stage_id.0)
184 }
185
186 pub fn warmup(&self) -> std::thread::JoinHandle<()> {
196 let nix_bin = self.nix_bin.clone();
197 std::thread::spawn(move || {
198 let status = Command::new(&nix_bin)
201 .args([
202 "build",
203 "--no-link",
204 "--quiet",
205 "--no-write-lock-file",
206 "nixpkgs#python3",
207 ])
208 .stdout(Stdio::null())
209 .stderr(Stdio::null())
210 .status();
211 match status {
212 Ok(s) if s.success() => {
213 eprintln!("[noether] nix warmup: python3 runtime cached");
214 }
215 Ok(s) => {
216 eprintln!("[noether] nix warmup: exited with {s} (non-fatal)");
217 }
218 Err(e) => {
219 eprintln!("[noether] nix warmup: failed to spawn ({e}) (non-fatal)");
220 }
221 }
222 })
223 }
224
225 fn code_hash(code: &str) -> String {
229 hex::encode(Sha256::digest(code.as_bytes()))
230 }
231
232 fn ensure_script(
235 &self,
236 impl_hash: &str,
237 code: &str,
238 language: &str,
239 ) -> Result<PathBuf, ExecutionError> {
240 let ext = match language {
241 "javascript" | "js" => "js",
242 "bash" | "sh" => "sh",
243 _ => "py",
244 };
245
246 let path = self.cache_dir.join(format!("{impl_hash}.{ext}"));
247 if path.exists() {
248 return Ok(path);
249 }
250
251 let wrapped = match language {
252 "javascript" | "js" => Self::wrap_javascript(code),
253 "bash" | "sh" => Self::wrap_bash(code),
254 _ => Self::wrap_python(code),
255 };
256
257 std::fs::write(&path, &wrapped).map_err(|e| ExecutionError::StageFailed {
258 stage_id: StageId(impl_hash.into()),
259 message: format!("failed to write stage script: {e}"),
260 })?;
261
262 Ok(path)
263 }
264
265 fn run_script(
268 &self,
269 stage_id: &StageId,
270 script: &Path,
271 language: &str,
272 input: &Value,
273 ) -> Result<Value, ExecutionError> {
274 let input_json = serde_json::to_string(input).unwrap_or_default();
275
276 let code = self
277 .implementations
278 .get(&stage_id.0)
279 .map(|i| i.code.as_str())
280 .unwrap_or("");
281
282 let (nix_subcommand, args) = self.build_nix_command(language, script, code);
283
284 let mut child = if nix_subcommand == "__direct__" {
286 Command::new(&args[0])
287 .args(&args[1..])
288 .stdin(Stdio::piped())
289 .stdout(Stdio::piped())
290 .stderr(Stdio::piped())
291 .spawn()
292 } else {
293 Command::new(&self.nix_bin)
294 .arg(&nix_subcommand)
295 .args(["--no-write-lock-file", "--quiet"])
296 .args(&args)
297 .stdin(Stdio::piped())
298 .stdout(Stdio::piped())
299 .stderr(Stdio::piped())
300 .spawn()
301 }
302 .map_err(|e| ExecutionError::StageFailed {
303 stage_id: stage_id.clone(),
304 message: format!("failed to spawn process: {e}"),
305 })?;
306
307 if let Some(mut stdin) = child.stdin.take() {
310 let bytes = input_json.into_bytes();
311 std::thread::spawn(move || {
312 let _ = stdin.write_all(&bytes);
313 });
314 }
315
316 let pid = child.id();
318 let timeout = Duration::from_secs(self.config.timeout_secs);
319 let (tx, rx) = mpsc::channel();
320 std::thread::spawn(move || {
321 let _ = tx.send(child.wait_with_output());
322 });
323
324 let out = match rx.recv_timeout(timeout) {
325 Ok(Ok(o)) => o,
326 Ok(Err(e)) => {
327 return Err(ExecutionError::StageFailed {
328 stage_id: stage_id.clone(),
329 message: format!("nix process error: {e}"),
330 });
331 }
332 Err(_elapsed) => {
333 let _ = Command::new("kill").args(["-9", &pid.to_string()]).status();
335 return Err(ExecutionError::TimedOut {
336 stage_id: stage_id.clone(),
337 timeout_secs: self.config.timeout_secs,
338 });
339 }
340 };
341
342 let stderr_raw = &out.stderr[..out.stderr.len().min(self.config.max_stderr_bytes)];
344 let stderr = String::from_utf8_lossy(stderr_raw);
345
346 if !out.status.success() {
347 return Err(ExecutionError::StageFailed {
348 stage_id: stage_id.clone(),
349 message: Self::classify_error(&stderr, out.status.code()),
350 });
351 }
352
353 let stdout_raw = &out.stdout[..out.stdout.len().min(self.config.max_output_bytes)];
355 let stdout = String::from_utf8_lossy(stdout_raw);
356
357 if stdout_raw.len() == self.config.max_output_bytes && !out.stdout.is_empty() {
358 return Err(ExecutionError::StageFailed {
359 stage_id: stage_id.clone(),
360 message: format!(
361 "stage output exceeded {} bytes limit",
362 self.config.max_output_bytes
363 ),
364 });
365 }
366
367 serde_json::from_str(stdout.trim()).map_err(|e| ExecutionError::StageFailed {
368 stage_id: stage_id.clone(),
369 message: format!("failed to parse stage output as JSON: {e} (got: {stdout:?})"),
370 })
371 }
372
373 fn classify_error(stderr: &str, exit_code: Option<i32>) -> String {
376 if stderr.contains("cannot connect to nix daemon")
378 || stderr.contains("Cannot connect to the Nix daemon")
379 {
380 return "nix daemon is not running — start it with `sudo systemctl start nix-daemon` \
381 or `nix daemon`"
382 .to_string();
383 }
384 if stderr.contains("error: flake") || stderr.contains("error: getting flake") {
385 return format!(
386 "nix flake error (check network / nixpkgs access): {}",
387 first_line(stderr)
388 );
389 }
390 if stderr.contains("error: downloading") || stderr.contains("error: fetching") {
391 return format!(
392 "nix failed to fetch runtime package (check network): {}",
393 first_line(stderr)
394 );
395 }
396 if stderr.contains("out of disk space") || stderr.contains("No space left on device") {
397 return "nix store out of disk space — run `nix-collect-garbage -d` to free space"
398 .to_string();
399 }
400 if stderr.contains("nix: command not found") || stderr.contains("No such file") {
401 return "nix binary not found — is Nix installed?".to_string();
402 }
403 let code_str = exit_code
405 .map(|c| format!(" (exit {c})"))
406 .unwrap_or_default();
407 if stderr.trim().is_empty() {
408 format!("stage exited without output{code_str}")
409 } else {
410 format!("stage error{code_str}: {stderr}")
411 }
412 }
413
414 fn build_nix_command(
420 &self,
421 language: &str,
422 script: &Path,
423 code: &str,
424 ) -> (String, Vec<String>) {
425 let script_path = script.to_str().unwrap_or("/dev/null").to_string();
426
427 match language {
428 "python" | "python3" | "" => {
429 if let Some(reqs) = Self::extract_pip_requirements(code) {
433 let venv_hash = {
434 use sha2::{Digest, Sha256};
435 let h = Sha256::digest(reqs.as_bytes());
436 hex::encode(&h[..8])
437 };
438 let venv_dir = self.cache_dir.join(format!("venv-{venv_hash}"));
439 let venv_str = venv_dir.to_string_lossy().to_string();
440 let python = venv_dir.join("bin").join("python3");
441 let python_str = python.to_string_lossy().to_string();
442
443 if !python.exists() {
445 let setup = std::process::Command::new("python3")
446 .args(["-m", "venv", &venv_str])
447 .output();
448 if let Ok(out) = setup {
449 if out.status.success() {
450 let pip = venv_dir.join("bin").join("pip");
451 let pkgs: Vec<&str> = reqs.split(", ").collect();
452 let mut pip_args =
453 vec!["install", "--quiet", "--disable-pip-version-check"];
454 pip_args.extend(pkgs);
455 let _ = std::process::Command::new(pip.to_string_lossy().as_ref())
456 .args(&pip_args)
457 .output();
458 }
459 }
460 }
461
462 return ("__direct__".to_string(), vec![python_str, script_path]);
464 }
465
466 let extra_pkgs = Self::detect_python_packages(code);
467 if extra_pkgs.is_empty() {
468 (
469 "run".to_string(),
470 vec!["nixpkgs#python3".into(), "--".into(), script_path],
471 )
472 } else {
473 let mut args: Vec<String> = extra_pkgs
474 .iter()
475 .map(|pkg| format!("nixpkgs#python3Packages.{pkg}"))
476 .collect();
477 args.extend_from_slice(&["--command".into(), "python3".into(), script_path]);
478 ("shell".to_string(), args)
479 }
480 }
481 "javascript" | "js" => (
482 "run".to_string(),
483 vec!["nixpkgs#nodejs".into(), "--".into(), script_path],
484 ),
485 _ => (
486 "run".to_string(),
487 vec!["nixpkgs#bash".into(), "--".into(), script_path],
488 ),
489 }
490 }
491
492 fn extract_pip_requirements(code: &str) -> Option<String> {
494 for line in code.lines() {
495 let trimmed = line.trim();
496 if trimmed.starts_with("# requires:") {
497 let reqs = trimmed.strip_prefix("# requires:").unwrap().trim();
498 if !reqs.is_empty() {
499 return Some(reqs.to_string());
500 }
501 }
502 }
503 None
504 }
505
506 fn detect_python_packages(code: &str) -> Vec<&'static str> {
509 const KNOWN: &[(&str, &str)] = &[
511 ("requests", "requests"),
512 ("httpx", "httpx"),
513 ("aiohttp", "aiohttp"),
514 ("bs4", "beautifulsoup4"),
515 ("lxml", "lxml"),
516 ("pandas", "pandas"),
517 ("numpy", "numpy"),
518 ("scipy", "scipy"),
519 ("sklearn", "scikit-learn"),
520 ("PIL", "Pillow"),
521 ("cv2", "opencv4"),
522 ("yaml", "pyyaml"),
523 ("toml", "toml"),
524 ("dateutil", "python-dateutil"),
525 ("pytz", "pytz"),
526 ("boto3", "boto3"),
527 ("psycopg2", "psycopg2"),
528 ("pymongo", "pymongo"),
529 ("redis", "redis"),
530 ("celery", "celery"),
531 ("fastapi", "fastapi"),
532 ("pydantic", "pydantic"),
533 ("cryptography", "cryptography"),
534 ("jwt", "pyjwt"),
535 ("paramiko", "paramiko"),
536 ("dotenv", "python-dotenv"),
537 ("joblib", "joblib"),
538 ("torch", "pytorch"),
539 ("transformers", "transformers"),
540 ("datasets", "datasets"),
541 ("pyarrow", "pyarrow"),
542 ];
543
544 let mut found: Vec<&'static str> = Vec::new();
545 for (import_name, nix_name) in KNOWN {
546 let patterns = [
547 format!("import {import_name}"),
548 format!("import {import_name} "),
549 format!("from {import_name} "),
550 format!("from {import_name}."),
551 ];
552 if patterns.iter().any(|p| code.contains(p.as_str())) {
553 found.push(nix_name);
554 }
555 }
556 found
557 }
558
559 fn wrap_python(user_code: &str) -> String {
562 let pip_install = String::new();
566
567 format!(
568 r#"import sys, json as _json
569{pip_install}
570# ---- user implementation ----
571{user_code}
572# ---- end implementation ----
573
574if __name__ == '__main__':
575 if 'execute' not in dir() or not callable(globals().get('execute')):
576 print(
577 "Noether stage error: implementation must define a top-level "
578 "function `def execute(input): ...` that takes the parsed input dict "
579 "and returns the output dict. Do not read from stdin or print to stdout — "
580 "the Noether runtime handles I/O for you.",
581 file=sys.stderr,
582 )
583 sys.exit(1)
584 try:
585 _raw = _json.loads(sys.stdin.read())
586 # If the runtime passed input as a JSON-encoded string, decode it once more.
587 # This happens when input arrives as null or a bare string from the CLI.
588 if isinstance(_raw, str):
589 try:
590 _raw = _json.loads(_raw)
591 except Exception:
592 pass
593 _output = execute(_raw if _raw is not None else {{}})
594 print(_json.dumps(_output))
595 except Exception as _e:
596 print(str(_e), file=sys.stderr)
597 sys.exit(1)
598"#
599 )
600 }
601
602 fn wrap_javascript(user_code: &str) -> String {
603 format!(
604 r#"const _readline = require('readline');
605let _input = '';
606process.stdin.on('data', d => _input += d);
607process.stdin.on('end', () => {{
608 try {{
609 // ---- user implementation ----
610 {user_code}
611 // ---- end implementation ----
612 const _result = execute(JSON.parse(_input));
613 process.stdout.write(JSON.stringify(_result) + '\n');
614 }} catch (e) {{
615 process.stderr.write(String(e) + '\n');
616 process.exit(1);
617 }}
618}});
619"#
620 )
621 }
622
623 fn wrap_bash(user_code: &str) -> String {
624 format!(
625 r#"#!/usr/bin/env bash
626set -euo pipefail
627INPUT=$(cat)
628
629# ---- user implementation ----
630{user_code}
631# ---- end implementation ----
632
633execute "$INPUT"
634"#
635 )
636 }
637}
638
/// Returns the first non-blank line of `s`, trimmed; `s` itself if every line is
/// blank.
fn first_line(s: &str) -> &str {
    for line in s.lines() {
        let line = line.trim();
        if !line.is_empty() {
            return line;
        }
    }
    s
}
648
649impl StageExecutor for NixExecutor {
652 fn execute(&self, stage_id: &StageId, input: &Value) -> Result<Value, ExecutionError> {
653 let impl_ = self
654 .implementations
655 .get(&stage_id.0)
656 .ok_or_else(|| ExecutionError::StageNotFound(stage_id.clone()))?;
657
658 let code_hash = Self::code_hash(&impl_.code);
659 let script = self.ensure_script(&code_hash, &impl_.code, &impl_.language)?;
660 self.run_script(stage_id, &script, &impl_.language, input)
661 }
662}
663
#[cfg(test)]
mod tests {
    use super::*;

    /// Executor with placeholder paths, for tests that only build commands and
    /// never spawn a process.
    fn test_executor() -> NixExecutor {
        NixExecutor {
            nix_bin: PathBuf::from("/usr/bin/nix"),
            cache_dir: PathBuf::from("/tmp/noether-test-cache"),
            config: NixConfig::default(),
            implementations: HashMap::new(),
        }
    }

    /// Executor backed by the real `nix` binary with one Python stage `code`
    /// registered under `id`, caching scripts in `<temp dir>/<cache_subdir>`.
    ///
    /// Returns `None` when nix cannot be found, so callers can skip.
    fn nix_executor_with_python_stage(
        id: &StageId,
        code: &str,
        cache_subdir: &str,
        config: NixConfig,
    ) -> Option<NixExecutor> {
        let nix_bin = NixExecutor::find_nix()?;
        let cache_dir = std::env::temp_dir().join(cache_subdir);
        let _ = std::fs::create_dir_all(&cache_dir);
        let mut executor = NixExecutor {
            nix_bin,
            cache_dir,
            config,
            implementations: HashMap::new(),
        };
        executor.register(id, code, "python");
        Some(executor)
    }

    #[test]
    fn detect_python_packages_requests() {
        let code = "import requests\ndef execute(v):\n return requests.get(v).json()";
        let pkgs = NixExecutor::detect_python_packages(code);
        assert!(
            pkgs.contains(&"requests"),
            "expected 'requests' in {pkgs:?}"
        );
    }

    #[test]
    fn detect_python_packages_stdlib_only() {
        let code = "import urllib.request, json\ndef execute(v):\n return json.loads(v)";
        let pkgs = NixExecutor::detect_python_packages(code);
        assert!(
            pkgs.is_empty(),
            "stdlib imports should not trigger packages: {pkgs:?}"
        );
    }

    #[test]
    fn detect_python_packages_multiple() {
        let code = "import pandas\nimport numpy as np\nfrom bs4 import BeautifulSoup\ndef execute(v): pass";
        let pkgs = NixExecutor::detect_python_packages(code);
        assert!(pkgs.contains(&"pandas"));
        assert!(pkgs.contains(&"numpy"));
        assert!(pkgs.contains(&"beautifulsoup4"));
    }

    #[test]
    fn build_nix_command_no_packages() {
        let exec = test_executor();
        let (sub, args) = exec.build_nix_command("python", Path::new("/tmp/x.py"), "import json");
        assert_eq!(sub, "run");
        assert!(args.iter().any(|a| a.contains("python3")));
        assert!(!args.iter().any(|a| a.contains("shell")));
    }

    #[test]
    fn build_nix_command_with_requests() {
        let exec = test_executor();
        let (sub, args) =
            exec.build_nix_command("python", Path::new("/tmp/x.py"), "import requests");
        assert_eq!(sub, "shell");
        assert!(args.iter().any(|a| a.contains("python3Packages.requests")));
        assert!(args.iter().any(|a| a == "--command"));
        // A bare `nixpkgs#python3` alongside the package attributes would conflict.
        assert!(
            !args.iter().any(|a| a == "nixpkgs#python3"),
            "bare python3 conflicts: {args:?}"
        );
    }

    #[test]
    fn python_wrapper_contains_boilerplate() {
        let wrapped = NixExecutor::wrap_python("def execute(x):\n return x + 1");
        assert!(wrapped.contains("sys.stdin.read()"));
        assert!(wrapped.contains("_json.dumps(_output)"));
        assert!(wrapped.contains("def execute(x)"));
    }

    #[test]
    fn code_hash_is_stable() {
        let h1 = NixExecutor::code_hash("hello world");
        let h2 = NixExecutor::code_hash("hello world");
        let h3 = NixExecutor::code_hash("different");
        assert_eq!(h1, h2);
        assert_ne!(h1, h3);
    }

    #[test]
    fn classify_error_daemon_not_running() {
        let msg = NixExecutor::classify_error("error: cannot connect to nix daemon", Some(1));
        assert!(msg.contains("nix daemon is not running"), "got: {msg}");
    }

    #[test]
    fn classify_error_user_code_exit1() {
        let msg = NixExecutor::classify_error("ValueError: invalid input", Some(1));
        assert!(msg.contains("ValueError"), "got: {msg}");
        assert!(msg.contains("exit 1"), "got: {msg}");
    }

    #[test]
    fn classify_error_disk_full() {
        let msg = NixExecutor::classify_error("No space left on device", Some(1));
        assert!(msg.contains("disk space"), "got: {msg}");
    }

    #[test]
    fn classify_error_empty_stderr() {
        let msg = NixExecutor::classify_error("", Some(137));
        assert!(msg.contains("exit 137"), "got: {msg}");
    }

    #[test]
    fn nix_config_defaults() {
        let cfg = NixConfig::default();
        assert_eq!(cfg.timeout_secs, 30);
        assert_eq!(cfg.max_output_bytes, 10 * 1024 * 1024);
        assert_eq!(cfg.max_stderr_bytes, 64 * 1024);
    }

    #[test]
    fn first_line_extracts_correctly() {
        assert_eq!(first_line(" \nfoo\nbar"), "foo");
        assert_eq!(first_line("single"), "single");
        assert_eq!(first_line(""), "");
    }

    #[test]
    #[ignore = "requires nix + warm binary cache; run manually with `cargo test -- --ignored`"]
    fn nix_python_identity_stage() {
        let id = StageId("test_identity".into());
        let executor = match nix_executor_with_python_stage(
            &id,
            "def execute(x):\n return x",
            "noether-nix-integ",
            NixConfig::default(),
        ) {
            Some(executor) => executor,
            None => {
                eprintln!("nix not found, skipping");
                return;
            }
        };

        let result = executor.execute(&id, &serde_json::json!({"hello": "world"}));
        assert_eq!(result.unwrap(), serde_json::json!({"hello": "world"}));
    }

    #[test]
    #[ignore = "requires nix + warm binary cache; run manually with `cargo test -- --ignored`"]
    fn nix_timeout_kills_hanging_stage() {
        let id = StageId("hanging".into());
        let config = NixConfig {
            timeout_secs: 2,
            ..NixConfig::default()
        };
        let executor = match nix_executor_with_python_stage(
            &id,
            "import time\ndef execute(x):\n time.sleep(9999)\n return x",
            "noether-nix-timeout",
            config,
        ) {
            Some(executor) => executor,
            None => {
                eprintln!("nix not found, skipping timeout test");
                return;
            }
        };

        let result = executor.execute(&id, &serde_json::json!(null));
        assert!(
            matches!(
                result,
                Err(ExecutionError::TimedOut {
                    timeout_secs: 2,
                    ..
                })
            ),
            "expected TimedOut, got: {result:?}"
        );
    }
}