greentic_runner_host/runner/mocks.rs

1use anyhow::{Context, Result};
2use hex::encode;
3use parking_lot::Mutex;
4use serde::{Deserialize, Serialize};
5use serde_json::{Value, json};
6use sha2::{Digest, Sha256};
7use std::collections::{BTreeMap, HashMap, HashSet};
8use std::fs;
9use std::io::Write;
10use std::path::{Path, PathBuf};
11use std::sync::{Arc, Weak};
12use url::Url;
13
14pub trait MockEventSink: Send + Sync {
15    fn on_mock_event(&self, capability: &str, provider: &str, payload: &Value);
16}
17
18#[derive(Clone, Debug, Default, Serialize, Deserialize)]
19pub struct MocksConfig {
20    pub http: Option<HttpMock>,
21    pub secrets: Option<SecretsMock>,
22    pub kv: Option<KvMock>,
23    pub telemetry: Option<TelemetryMock>,
24    pub mcp_tools: Option<ToolsMock>,
25    pub time: Option<TimeMock>,
26    #[serde(default)]
27    pub net_allowlist: Vec<String>,
28}
29
30#[derive(Clone, Debug, Default, Serialize, Deserialize)]
31pub struct HttpMock {
32    pub record_replay_dir: Option<PathBuf>,
33    pub mode: HttpMockMode,
34    #[serde(default)]
35    pub rewrites: Vec<(String, String)>,
36}
37
38#[derive(Clone, Debug, Serialize, Deserialize, Default)]
39pub enum HttpMockMode {
40    #[default]
41    Off,
42    Replay,
43    Record,
44    RecordReplay,
45    FailOnMiss,
46}
47
48#[derive(Clone, Debug, Default, Serialize, Deserialize)]
49pub struct SecretsMock {
50    pub map: BTreeMap<String, String>,
51}
52
53#[derive(Clone, Debug, Default, Serialize, Deserialize)]
54pub struct KvMock;
55
56#[derive(Clone, Debug, Default, Serialize, Deserialize)]
57pub struct TelemetryMock;
58
59#[derive(Clone, Debug, Default, Serialize, Deserialize)]
60pub struct ToolsMock {
61    pub directory: Option<PathBuf>,
62    pub script_dir: Option<PathBuf>,
63    pub short_circuit: bool,
64}
65
66#[derive(Clone, Debug, Default, Serialize, Deserialize)]
67pub struct TimeMock;
68
69pub struct MockLayer {
70    config: MocksConfig,
71    http: Option<HttpMockRuntime>,
72    sinks: Mutex<Vec<Weak<dyn MockEventSink>>>,
73    net_allowlist: HashSet<String>,
74}
75
76impl MockLayer {
77    pub fn new(config: MocksConfig, run_dir: &Path) -> Result<Self> {
78        let http = if let Some(http_cfg) = &config.http {
79            match http_cfg.mode {
80                HttpMockMode::Off => None,
81                _ => Some(HttpMockRuntime::new(http_cfg, run_dir)?),
82            }
83        } else {
84            None
85        };
86        let net_allowlist = config
87            .net_allowlist
88            .iter()
89            .map(|value| value.to_ascii_lowercase())
90            .collect();
91        Ok(Self {
92            config,
93            http,
94            sinks: Mutex::new(Vec::new()),
95            net_allowlist,
96        })
97    }
98
99    pub fn register_sink(&self, sink: Arc<dyn MockEventSink>) {
100        let mut guard = self.sinks.lock();
101        guard.retain(|entry| entry.upgrade().is_some());
102        guard.push(Arc::downgrade(&sink));
103    }
104
105    fn emit_event(&self, capability: &str, provider: &str, payload: Value) {
106        let mut guard = self.sinks.lock();
107        guard.retain(|entry| entry.upgrade().is_some());
108        for weak in guard.iter() {
109            if let Some(strong) = weak.upgrade() {
110                strong.on_mock_event(capability, provider, &payload);
111            }
112        }
113    }
114
115    pub fn secrets_lookup(&self, key: &str) -> Option<String> {
116        let secrets = self.config.secrets.as_ref()?;
117        secrets.map.get(key).map(|value| {
118            self.emit_event("secrets", "mock", json!({ "key": key, "source": "map" }));
119            value.clone()
120        })
121    }
122
123    pub fn telemetry_drain(&self, fields: &[(&str, &str)]) -> bool {
124        if self.config.telemetry.is_none() {
125            return false;
126        }
127        let entries = fields
128            .iter()
129            .map(|(k, v)| (k.to_string(), v.to_string()))
130            .collect::<BTreeMap<_, _>>();
131        self.emit_event("telemetry", "mock", json!({ "fields": entries }));
132        true
133    }
134
135    pub fn tool_short_circuit(&self, tool: &str, action: &str) -> Option<Result<Value>> {
136        let tools = self.config.mcp_tools.as_ref()?;
137        if !tools.short_circuit {
138            return None;
139        }
140        let script_dir = tools.script_dir.as_ref()?;
141        let filename = format!("{}__{}.json", sanitize(tool), sanitize(action));
142        let path = script_dir.join(filename);
143        let body = fs::read_to_string(&path)
144            .with_context(|| format!("failed to read mock script {}", path.display()));
145        let result = body.and_then(|text| {
146            serde_json::from_str(&text)
147                .with_context(|| format!("mock script {} is not valid json", path.display()))
148        });
149        Some(match result {
150            Ok(value) => {
151                self.emit_event(
152                    "tools",
153                    "mock",
154                    json!({ "tool": tool, "action": action, "script": path }),
155                );
156                Ok(value)
157            }
158            Err(err) => Err(err),
159        })
160    }
161
162    pub fn http_begin(&self, request: &HttpMockRequest) -> HttpDecision {
163        let runtime = match &self.http {
164            Some(runtime) => runtime,
165            None => return HttpDecision::Passthrough { record: false },
166        };
167
168        if let Some(response) = runtime.replay(request) {
169            self.emit_event(
170                "http",
171                "mock",
172                json!({
173                    "url": request.url,
174                    "method": request.method,
175                    "mode": "replay"
176                }),
177            );
178            return HttpDecision::Mock(response);
179        }
180
181        if !self.allow_host(&request.url) {
182            return HttpDecision::Deny(format!("host {} not present in allowlist", request.host));
183        }
184
185        if runtime.should_record() {
186            HttpDecision::Passthrough { record: true }
187        } else {
188            match runtime.mode {
189                HttpMockMode::RecordReplay | HttpMockMode::Record => {
190                    HttpDecision::Passthrough { record: false }
191                }
192                _ => HttpDecision::Deny("no recorded response".into()),
193            }
194        }
195    }
196
197    pub fn http_record(&self, request: &HttpMockRequest, response: &HttpMockResponse) {
198        if let Some(runtime) = &self.http
199            && runtime.record(request, response).is_ok()
200        {
201            self.emit_event(
202                "http",
203                "mock",
204                json!({
205                    "url": request.url,
206                    "method": request.method,
207                    "mode": "record"
208                }),
209            );
210        }
211    }
212
213    fn allow_host(&self, url: &str) -> bool {
214        if self.net_allowlist.is_empty() {
215            return false;
216        }
217        Url::parse(url)
218            .ok()
219            .and_then(|parsed| parsed.host_str().map(|host| host.to_ascii_lowercase()))
220            .map(|host| self.net_allowlist.contains(&host))
221            .unwrap_or(false)
222    }
223}
224
225pub enum HttpDecision {
226    Mock(HttpMockResponse),
227    Deny(String),
228    Passthrough { record: bool },
229}
230
/// An outgoing HTTP request as seen by the cassette store.
pub struct HttpMockRequest {
    /// HTTP method exactly as supplied (not case-normalised).
    pub method: String,
    /// Full request URL exactly as supplied.
    pub url: String,
    /// Lower-cased host taken from `url`; empty when the URL has no host.
    pub host: String,
    /// Hex SHA-256 over method, URL and body; it also names the cassette file.
    pub fingerprint: String,
}
237
238impl HttpMockRequest {
239    pub fn new(method: &str, url: &str, body: Option<&[u8]>) -> Result<Self> {
240        let parsed = Url::parse(url).with_context(|| format!("invalid url {url}"))?;
241        let host = parsed
242            .host_str()
243            .map(|s| s.to_ascii_lowercase())
244            .unwrap_or_default();
245        let mut hasher = Sha256::new();
246        hasher.update(method.as_bytes());
247        hasher.update(url.as_bytes());
248        if let Some(body) = body {
249            hasher.update(body);
250        }
251        let fingerprint = encode(hasher.finalize());
252        Ok(Self {
253            method: method.to_string(),
254            url: url.to_string(),
255            host,
256            fingerprint,
257        })
258    }
259}
260
261#[derive(Clone, Debug, Serialize, Deserialize)]
262pub struct HttpMockResponse {
263    pub status: u16,
264    pub headers: BTreeMap<String, String>,
265    pub body: Option<String>,
266}
267
268impl HttpMockResponse {
269    pub fn new(status: u16, headers: BTreeMap<String, String>, body: Option<String>) -> Self {
270        Self {
271            status,
272            headers,
273            body,
274        }
275    }
276}
277
278struct HttpMockRuntime {
279    mode: HttpMockMode,
280    dir: Option<PathBuf>,
281    entries: Mutex<HashMap<String, HttpMockResponse>>,
282}
283
284impl HttpMockRuntime {
285    fn new(config: &HttpMock, run_dir: &Path) -> Result<Self> {
286        let dir = config
287            .record_replay_dir
288            .clone()
289            .or_else(|| Some(run_dir.join("cassettes/http")));
290        if let Some(path) = dir.as_ref() {
291            let _ = fs::create_dir_all(path);
292        }
293        let entries = Mutex::new(Self::load_entries(dir.as_ref())?);
294        Ok(Self {
295            mode: config.mode.clone(),
296            dir,
297            entries,
298        })
299    }
300
301    fn load_entries(dir: Option<&PathBuf>) -> Result<HashMap<String, HttpMockResponse>> {
302        let mut map = HashMap::new();
303        if let Some(dir) = dir
304            && dir.exists()
305        {
306            for entry in fs::read_dir(dir)? {
307                let entry = entry?;
308                let path = entry.path();
309                if path.extension().and_then(|s| s.to_str()) != Some("json") {
310                    continue;
311                }
312                let bytes = fs::read(&path)?;
313                let resp: HttpMockResponse = serde_json::from_slice(&bytes)?;
314                if let Some(stem) = path.file_stem().and_then(|s| s.to_str()) {
315                    map.insert(stem.to_string(), resp);
316                }
317            }
318        }
319        Ok(map)
320    }
321
322    fn replay(&self, req: &HttpMockRequest) -> Option<HttpMockResponse> {
323        let guard = self.entries.lock();
324        guard.get(&req.fingerprint).cloned()
325    }
326
327    fn should_record(&self) -> bool {
328        matches!(self.mode, HttpMockMode::Record | HttpMockMode::RecordReplay)
329    }
330
331    fn record(&self, req: &HttpMockRequest, resp: &HttpMockResponse) -> Result<()> {
332        if !self.should_record() {
333            return Ok(());
334        }
335        let dir = match &self.dir {
336            Some(dir) => dir,
337            None => return Ok(()),
338        };
339        let path = dir.join(format!("{}.json", &req.fingerprint));
340        let mut file = fs::File::create(&path)
341            .with_context(|| format!("failed to write {}", path.display()))?;
342        let body = serde_json::to_vec_pretty(resp)?;
343        file.write_all(&body)?;
344        self.entries
345            .lock()
346            .insert(req.fingerprint.clone(), resp.clone());
347        Ok(())
348    }
349}
350
/// Replaces every character that is not an ASCII letter or digit with `_`, so
/// tool and action names can be embedded safely in a file name.
fn sanitize(value: &str) -> String {
    let mut cleaned = String::with_capacity(value.len());
    for c in value.chars() {
        cleaned.push(if c.is_ascii_alphanumeric() { c } else { '_' });
    }
    cleaned
}