Skip to main content

observer_rust_host/
lib.rs

1// SPDX-FileCopyrightText: 2026 Alexander R. Croft
2// SPDX-License-Identifier: GPL-3.0-or-later
3
4use anyhow::{anyhow, Result};
5use base64::engine::general_purpose::STANDARD;
6use base64::Engine;
7use clap::{Parser, Subcommand};
8use observer_rust::{
9    sorted_validated_tests, Registry, TelemetryEntry as RustTelemetryEntry,
10    TelemetryValue as RustTelemetryValue, TestContext, TestOutcome, TestRegistration,
11};
12use serde::{Deserialize, Serialize};
13use serde_json::Number;
14use std::panic::{catch_unwind, AssertUnwindSafe};
15use std::process::{Command as ProcessCommand, Stdio};
16use std::thread;
17use std::time::{Duration, Instant};
18
19#[derive(Debug, Parser)]
20#[command(name = "observer-rust-host")]
21#[command(about = "Rust provider host for Observer")]
22struct Cli {
23    #[command(subcommand)]
24    command: Command,
25}
26
27#[derive(Debug, Subcommand)]
28enum Command {
29    List,
30    Run {
31        #[arg(long)]
32        target: String,
33        #[arg(long = "timeout-ms")]
34        timeout_ms: u32,
35    },
36    #[command(hide = true, name = "__exec-target")]
37    ExecTarget {
38        #[arg(long)]
39        provider: String,
40        #[arg(long)]
41        target: String,
42    },
43}
44
45#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
46pub struct ListPayload {
47    pub provider: String,
48    pub tests: Vec<ListEntry>,
49}
50
51#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
52pub struct ListEntry {
53    pub name: String,
54    pub target: String,
55}
56
57#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
58pub struct RunPayload {
59    pub provider: String,
60    pub target: String,
61    pub exit: i32,
62    pub out_b64: String,
63    pub err_b64: String,
64    #[serde(skip_serializing_if = "Option::is_none")]
65    pub telemetry: Option<Vec<TelemetryEntry>>,
66}
67
68#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
69pub struct TelemetryEntry {
70    pub name: String,
71    #[serde(skip_serializing_if = "Option::is_none")]
72    pub unit: Option<String>,
73    #[serde(flatten)]
74    pub value: TelemetryValue,
75}
76
77#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
78#[serde(tag = "kind", rename_all = "snake_case")]
79pub enum TelemetryValue {
80    Metric { value: Number },
81    Vector { values: Vec<Number> },
82    Tag { value: String },
83}
84
85pub fn list_payload<R: Registry>(provider: &str) -> Result<ListPayload> {
86    let tests = sorted_validated_tests::<R>().map_err(|error| anyhow!(error))?;
87    Ok(ListPayload {
88        provider: provider.to_owned(),
89        tests: tests
90            .into_iter()
91            .map(|test| ListEntry {
92                name: test.canonical_name.to_owned(),
93                target: test.target.to_owned(),
94            })
95            .collect(),
96    })
97}
98
99pub fn run_payload<R: Registry>(provider: &str, target: &str, timeout_ms: u32) -> Result<RunPayload> {
100    let _ = timeout_ms;
101    run_payload_direct::<R>(provider, target)
102}
103
104fn run_payload_direct<R: Registry>(provider: &str, target: &str) -> Result<RunPayload> {
105    let tests = sorted_validated_tests::<R>().map_err(|error| anyhow!(error))?;
106    let test = tests
107        .into_iter()
108        .find(|test| test.target == target)
109        .ok_or_else(|| anyhow!("unknown test target `{target}`"))?;
110
111    let outcome = execute_registration(test);
112    Ok(RunPayload {
113        provider: provider.to_owned(),
114        target: target.to_owned(),
115        exit: outcome.exit,
116        out_b64: STANDARD.encode(&outcome.out),
117        err_b64: STANDARD.encode(&outcome.err),
118        telemetry: payload_telemetry_from_outcome(outcome),
119    })
120}
121
122fn run_payload_with_timeout<R: Registry>(provider: &str, target: &str, timeout_ms: u32) -> Result<RunPayload> {
123    let started_at = Instant::now();
124    let usage_before = child_resource_usage();
125    let current_exe = std::env::current_exe()?;
126    let mut child = ProcessCommand::new(current_exe)
127        .arg("__exec-target")
128        .arg("--provider")
129        .arg(provider)
130        .arg("--target")
131        .arg(target)
132        .stdout(Stdio::piped())
133        .stderr(Stdio::piped())
134        .spawn()?;
135
136    let deadline = Instant::now() + Duration::from_millis(u64::from(timeout_ms));
137    loop {
138        match child.try_wait()? {
139            Some(status) => {
140                let output = child.wait_with_output()?;
141                let usage_after = child_resource_usage();
142                if !status.success() {
143                    let stderr = String::from_utf8_lossy(&output.stderr).trim().to_owned();
144                    return Err(anyhow!(if stderr.is_empty() {
145                        "provider host failed".to_owned()
146                    } else {
147                        stderr
148                    }));
149                }
150                let mut payload: RunPayload = serde_json::from_slice(&output.stdout)?;
151                append_runner_telemetry(&mut payload, started_at, usage_before, usage_after);
152                return Ok(payload);
153            }
154            None => {
155                if Instant::now() >= deadline {
156                    let _ = child.kill();
157                    let _ = child.wait();
158                    return Err(anyhow!("timeout"));
159                }
160                thread::sleep(Duration::from_millis(10));
161            }
162        }
163    }
164}
165
166fn payload_telemetry_from_outcome(outcome: TestOutcome) -> Option<Vec<TelemetryEntry>> {
167    let mut telemetry = Vec::new();
168    for entry in outcome.telemetry {
169        if let Some(entry) = convert_test_telemetry(entry) {
170            telemetry.push(entry);
171        }
172    }
173    if telemetry.is_empty() {
174        None
175    } else {
176        Some(telemetry)
177    }
178}
179
180fn convert_test_telemetry(entry: RustTelemetryEntry) -> Option<TelemetryEntry> {
181    let value = match entry.value {
182        RustTelemetryValue::Metric(value) => TelemetryValue::Metric { value },
183        RustTelemetryValue::Vector(values) => TelemetryValue::Vector { values },
184        RustTelemetryValue::Tag(value) => TelemetryValue::Tag { value },
185    };
186
187    Some(TelemetryEntry {
188        name: entry.name,
189        unit: None,
190        value,
191    })
192}
193
194fn append_runner_telemetry(
195    payload: &mut RunPayload,
196    started_at: Instant,
197    usage_before: Option<ChildResourceUsage>,
198    usage_after: Option<ChildResourceUsage>,
199) {
200    let mut telemetry = runner_telemetry(started_at, usage_before, usage_after);
201    if let Some(existing) = payload.telemetry.take() {
202        telemetry.extend(existing);
203    }
204    if telemetry.is_empty() {
205        payload.telemetry = None;
206    } else {
207        payload.telemetry = Some(telemetry);
208    }
209}
210
211fn runner_telemetry(
212    started_at: Instant,
213    usage_before: Option<ChildResourceUsage>,
214    usage_after: Option<ChildResourceUsage>,
215) -> Vec<TelemetryEntry> {
216    let mut telemetry = vec![TelemetryEntry {
217        name: "wall_time_ns".to_owned(),
218        unit: Some("ns".to_owned()),
219        value: TelemetryValue::Metric {
220            value: Number::from(started_at.elapsed().as_nanos() as u64),
221        },
222    }];
223
224    if let (Some(before), Some(after)) = (usage_before, usage_after) {
225        telemetry.push(TelemetryEntry {
226            name: "cpu_user_ns".to_owned(),
227            unit: Some("ns".to_owned()),
228            value: TelemetryValue::Metric {
229                value: Number::from(after.user_ns.saturating_sub(before.user_ns)),
230            },
231        });
232        telemetry.push(TelemetryEntry {
233            name: "cpu_system_ns".to_owned(),
234            unit: Some("ns".to_owned()),
235            value: TelemetryValue::Metric {
236                value: Number::from(after.system_ns.saturating_sub(before.system_ns)),
237            },
238        });
239        telemetry.push(TelemetryEntry {
240            name: "peak_rss_bytes".to_owned(),
241            unit: Some("bytes".to_owned()),
242            value: TelemetryValue::Metric {
243                value: Number::from(after.peak_rss_bytes),
244            },
245        });
246    }
247
248    telemetry
249}
250
/// Snapshot of the resource counters read for child processes.
#[derive(Clone, Copy, Debug)]
struct ChildResourceUsage {
    /// User-mode CPU time, in nanoseconds.
    user_ns: u64,
    /// Kernel-mode CPU time, in nanoseconds.
    system_ns: u64,
    /// Peak resident set size, in bytes.
    peak_rss_bytes: u64,
}
257
258#[cfg(unix)]
259fn child_resource_usage() -> Option<ChildResourceUsage> {
260    let mut usage = std::mem::MaybeUninit::<libc::rusage>::zeroed();
261    let status = unsafe { libc::getrusage(libc::RUSAGE_CHILDREN, usage.as_mut_ptr()) };
262    if status != 0 {
263        return None;
264    }
265
266    let usage = unsafe { usage.assume_init() };
267    Some(ChildResourceUsage {
268        user_ns: timeval_to_ns(usage.ru_utime),
269        system_ns: timeval_to_ns(usage.ru_stime),
270        peak_rss_bytes: peak_rss_bytes(usage.ru_maxrss),
271    })
272}
273
274#[cfg(not(unix))]
275fn child_resource_usage() -> Option<ChildResourceUsage> {
276    None
277}
278
279#[cfg(unix)]
280fn timeval_to_ns(value: libc::timeval) -> u64 {
281    let secs = u64::try_from(value.tv_sec).unwrap_or(0);
282    let micros = u64::try_from(value.tv_usec).unwrap_or(0);
283    secs.saturating_mul(1_000_000_000)
284        .saturating_add(micros.saturating_mul(1_000))
285}
286
287#[cfg(unix)]
288fn peak_rss_bytes(value: libc::c_long) -> u64 {
289    let value = u64::try_from(value).unwrap_or(0);
290    #[cfg(target_os = "linux")]
291    {
292        value.saturating_mul(1024)
293    }
294    #[cfg(not(target_os = "linux"))]
295    {
296        value
297    }
298}
299
300pub fn run_cli<R: Registry>(provider: &'static str) -> Result<()> {
301    run_cli_from::<R, _>(provider, std::env::args())
302}
303
304pub fn run_cli_from<R, I>(provider: &'static str, args: I) -> Result<()>
305where
306    R: Registry,
307    I: IntoIterator<Item = String>,
308{
309    let cli = Cli::parse_from(args);
310    match cli.command {
311        Command::List => println!("{}", serde_json::to_string(&list_payload::<R>(provider)?)?),
312        Command::Run { target, timeout_ms } => {
313            println!(
314                "{}",
315                serde_json::to_string(&run_payload_with_timeout::<R>(provider, &target, timeout_ms)?)?
316            )
317        }
318        Command::ExecTarget { provider, target } => println!(
319            "{}",
320            serde_json::to_string(&run_payload_direct::<R>(&provider, &target)?)?
321        ),
322    }
323    Ok(())
324}
325
326fn execute_registration(test: &TestRegistration) -> TestOutcome {
327    let mut ctx = TestContext::new();
328    let result = catch_unwind(AssertUnwindSafe(|| {
329        (test.function)(&mut ctx);
330    }));
331
332    match result {
333        Ok(()) => ctx.finish(),
334        Err(payload) => {
335            ctx.set_exit(1);
336            ctx.write_err(panic_payload_message(&payload).as_bytes());
337            ctx.finish()
338        }
339    }
340}
341
/// Renders a caught panic payload as a single `panic: …` line.
///
/// Payloads that are a `&str` or a `String` contribute their text; anything
/// else yields a fixed placeholder message.
// The parameter stays `&Box<…>`, the type `catch_unwind` hands back. With a
// `&dyn Any` parameter a caller's `&Box` could coerce to the box itself
// rather than its contents, and both downcasts would then miss.
fn panic_payload_message(payload: &Box<dyn std::any::Any + Send>) -> String {
    let text: Option<&str> = payload
        .downcast_ref::<&str>()
        .copied()
        .or_else(|| payload.downcast_ref::<String>().map(String::as_str));

    match text {
        Some(message) => format!("panic: {message}\n"),
        None => "panic: non-string payload\n".to_owned(),
    }
}
351
#[cfg(test)]
mod tests {
    use super::*;
    use observer_rust::{Registry, TestContext, TestRegistration};

    /// Test body that writes to stdout and emits one telemetry entry of each kind.
    fn passing(ctx: &mut TestContext) {
        ctx.write_out(b"ok\n");
        assert!(ctx.emit_metric("case_metric", 7.0));
        assert!(ctx.emit_vector("latency_ns", &[1.0, 2.0, 3.0]));
        assert!(ctx.emit_tag("resource_path", "fixtures/config.json"));
    }

    /// Test body that always panics, to exercise panic capture.
    fn panicking(_ctx: &mut TestContext) {
        panic!("boom");
    }

    /// Registry with two tests, declared out of canonical-name order on purpose.
    struct HostRegistry;

    impl Registry for HostRegistry {
        fn tests() -> &'static [TestRegistration] {
            static TESTS: [TestRegistration; 2] = [
                TestRegistration {
                    canonical_name: "B::Second",
                    target: "b::second",
                    function: panicking,
                    file: file!(),
                    line: line!(),
                    module_path: module_path!(),
                },
                TestRegistration {
                    canonical_name: "A::First",
                    target: "a::first",
                    function: passing,
                    file: file!(),
                    line: line!(),
                    module_path: module_path!(),
                },
            ];
            &TESTS
        }
    }

    #[test]
    fn list_payload_is_sorted_by_canonical_name() {
        let listing = list_payload::<HostRegistry>("rust").expect("list should succeed");
        let names: Vec<&str> = listing
            .tests
            .iter()
            .map(|entry| entry.name.as_str())
            .collect();
        assert_eq!(names[0], "A::First");
        assert_eq!(names[1], "B::Second");
    }

    #[test]
    fn run_payload_catches_panics_as_test_failures() {
        let result = run_payload::<HostRegistry>("rust", "b::second", 1000)
            .expect("run should succeed");
        assert_eq!(result.exit, 1);

        let stderr_bytes = STANDARD
            .decode(result.err_b64)
            .expect("base64 should decode");
        let stderr_text = String::from_utf8(stderr_bytes).expect("stderr should be utf-8");
        assert!(stderr_text.contains("panic: boom"));
    }

    #[test]
    fn run_payload_includes_test_emitted_telemetry() {
        let result = run_payload::<HostRegistry>("rust", "a::first", 1000)
            .expect("run should succeed");
        let entries = result.telemetry.expect("telemetry should exist");
        assert_eq!(entries.len(), 3);
        assert_eq!(entries[0].name, "case_metric");
        assert_eq!(entries[2].name, "resource_path");
    }
}
420}