1use anyhow::{Context, Result};
2use hex::encode;
3use parking_lot::Mutex;
4use serde::{Deserialize, Serialize};
5use serde_json::{Value, json};
6use sha2::{Digest, Sha256};
7use std::collections::{BTreeMap, HashMap, HashSet};
8use std::fs;
9use std::io::Write;
10use std::path::{Path, PathBuf};
11use std::sync::{Arc, Weak};
12use url::Url;
13
14pub trait MockEventSink: Send + Sync {
15 fn on_mock_event(&self, capability: &str, provider: &str, payload: &Value);
16}
17
18#[derive(Clone, Debug, Default, Serialize, Deserialize)]
19pub struct MocksConfig {
20 pub http: Option<HttpMock>,
21 pub secrets: Option<SecretsMock>,
22 pub kv: Option<KvMock>,
23 pub telemetry: Option<TelemetryMock>,
24 pub mcp_tools: Option<ToolsMock>,
25 pub time: Option<TimeMock>,
26 #[serde(default)]
27 pub net_allowlist: Vec<String>,
28}
29
30#[derive(Clone, Debug, Default, Serialize, Deserialize)]
31pub struct HttpMock {
32 pub record_replay_dir: Option<PathBuf>,
33 pub mode: HttpMockMode,
34 #[serde(default)]
35 pub rewrites: Vec<(String, String)>,
36}
37
38#[derive(Clone, Debug, Serialize, Deserialize, Default)]
39pub enum HttpMockMode {
40 #[default]
41 Off,
42 Replay,
43 Record,
44 RecordReplay,
45 FailOnMiss,
46}
47
48#[derive(Clone, Debug, Default, Serialize, Deserialize)]
49pub struct SecretsMock {
50 pub map: BTreeMap<String, String>,
51}
52
53#[derive(Clone, Debug, Default, Serialize, Deserialize)]
54pub struct KvMock;
55
56#[derive(Clone, Debug, Default, Serialize, Deserialize)]
57pub struct TelemetryMock;
58
59#[derive(Clone, Debug, Default, Serialize, Deserialize)]
60pub struct ToolsMock {
61 pub directory: Option<PathBuf>,
62 pub script_dir: Option<PathBuf>,
63 pub short_circuit: bool,
64}
65
66#[derive(Clone, Debug, Default, Serialize, Deserialize)]
67pub struct TimeMock;
68
69pub struct MockLayer {
70 config: MocksConfig,
71 http: Option<HttpMockRuntime>,
72 sinks: Mutex<Vec<Weak<dyn MockEventSink>>>,
73 net_allowlist: HashSet<String>,
74}
75
76impl MockLayer {
77 pub fn new(config: MocksConfig, run_dir: &Path) -> Result<Self> {
78 let http = if let Some(http_cfg) = &config.http {
79 match http_cfg.mode {
80 HttpMockMode::Off => None,
81 _ => Some(HttpMockRuntime::new(http_cfg, run_dir)?),
82 }
83 } else {
84 None
85 };
86 let net_allowlist = config
87 .net_allowlist
88 .iter()
89 .map(|value| value.to_ascii_lowercase())
90 .collect();
91 Ok(Self {
92 config,
93 http,
94 sinks: Mutex::new(Vec::new()),
95 net_allowlist,
96 })
97 }
98
99 pub fn register_sink(&self, sink: Arc<dyn MockEventSink>) {
100 let mut guard = self.sinks.lock();
101 guard.retain(|entry| entry.upgrade().is_some());
102 guard.push(Arc::downgrade(&sink));
103 }
104
105 fn emit_event(&self, capability: &str, provider: &str, payload: Value) {
106 let mut guard = self.sinks.lock();
107 guard.retain(|entry| entry.upgrade().is_some());
108 for weak in guard.iter() {
109 if let Some(strong) = weak.upgrade() {
110 strong.on_mock_event(capability, provider, &payload);
111 }
112 }
113 }
114
115 pub fn secrets_lookup(&self, key: &str) -> Option<String> {
116 let secrets = self.config.secrets.as_ref()?;
117 secrets.map.get(key).map(|value| {
118 self.emit_event("secrets", "mock", json!({ "key": key, "source": "map" }));
119 value.clone()
120 })
121 }
122
123 pub fn telemetry_drain(&self, fields: &[(&str, &str)]) -> bool {
124 if self.config.telemetry.is_none() {
125 return false;
126 }
127 let entries = fields
128 .iter()
129 .map(|(k, v)| (k.to_string(), v.to_string()))
130 .collect::<BTreeMap<_, _>>();
131 self.emit_event("telemetry", "mock", json!({ "fields": entries }));
132 true
133 }
134
135 pub fn tool_short_circuit(&self, tool: &str, action: &str) -> Option<Result<Value>> {
136 let tools = self.config.mcp_tools.as_ref()?;
137 if !tools.short_circuit {
138 return None;
139 }
140 let script_dir = tools.script_dir.as_ref()?;
141 let filename = format!("{}__{}.json", sanitize(tool), sanitize(action));
142 let path = script_dir.join(filename);
143 let body = fs::read_to_string(&path)
144 .with_context(|| format!("failed to read mock script {}", path.display()));
145 let result = body.and_then(|text| {
146 serde_json::from_str(&text)
147 .with_context(|| format!("mock script {} is not valid json", path.display()))
148 });
149 Some(match result {
150 Ok(value) => {
151 self.emit_event(
152 "tools",
153 "mock",
154 json!({ "tool": tool, "action": action, "script": path }),
155 );
156 Ok(value)
157 }
158 Err(err) => Err(err),
159 })
160 }
161
162 pub fn http_begin(&self, request: &HttpMockRequest) -> HttpDecision {
163 let runtime = match &self.http {
164 Some(runtime) => runtime,
165 None => return HttpDecision::Passthrough { record: false },
166 };
167
168 if let Some(response) = runtime.replay(request) {
169 self.emit_event(
170 "http",
171 "mock",
172 json!({
173 "url": request.url,
174 "method": request.method,
175 "mode": "replay"
176 }),
177 );
178 return HttpDecision::Mock(response);
179 }
180
181 if !self.allow_host(&request.url) {
182 return HttpDecision::Deny(format!("host {} not present in allowlist", request.host));
183 }
184
185 if runtime.should_record() {
186 HttpDecision::Passthrough { record: true }
187 } else {
188 match runtime.mode {
189 HttpMockMode::RecordReplay | HttpMockMode::Record => {
190 HttpDecision::Passthrough { record: false }
191 }
192 _ => HttpDecision::Deny("no recorded response".into()),
193 }
194 }
195 }
196
197 pub fn http_record(&self, request: &HttpMockRequest, response: &HttpMockResponse) {
198 if let Some(runtime) = &self.http
199 && runtime.record(request, response).is_ok()
200 {
201 self.emit_event(
202 "http",
203 "mock",
204 json!({
205 "url": request.url,
206 "method": request.method,
207 "mode": "record"
208 }),
209 );
210 }
211 }
212
213 fn allow_host(&self, url: &str) -> bool {
214 if self.net_allowlist.is_empty() {
215 return false;
216 }
217 Url::parse(url)
218 .ok()
219 .and_then(|parsed| parsed.host_str().map(|host| host.to_ascii_lowercase()))
220 .map(|host| self.net_allowlist.contains(&host))
221 .unwrap_or(false)
222 }
223}
224
225pub enum HttpDecision {
226 Mock(HttpMockResponse),
227 Deny(String),
228 Passthrough { record: bool },
229}
230
/// An outbound HTTP request as seen by the mock layer.
pub struct HttpMockRequest {
    /// HTTP method, as supplied by the caller.
    pub method: String,
    /// Full request URL, as supplied by the caller.
    pub url: String,
    /// Lowercased host of `url`; empty when the URL has no host.
    pub host: String,
    /// Hex-encoded SHA-256 identifying the request; also the cassette file stem.
    pub fingerprint: String,
}
237
238impl HttpMockRequest {
239 pub fn new(method: &str, url: &str, body: Option<&[u8]>) -> Result<Self> {
240 let parsed = Url::parse(url).with_context(|| format!("invalid url {url}"))?;
241 let host = parsed
242 .host_str()
243 .map(|s| s.to_ascii_lowercase())
244 .unwrap_or_default();
245 let mut hasher = Sha256::new();
246 hasher.update(method.as_bytes());
247 hasher.update(url.as_bytes());
248 if let Some(body) = body {
249 hasher.update(body);
250 }
251 let fingerprint = encode(hasher.finalize());
252 Ok(Self {
253 method: method.to_string(),
254 url: url.to_string(),
255 host,
256 fingerprint,
257 })
258 }
259}
260
261#[derive(Clone, Debug, Serialize, Deserialize)]
262pub struct HttpMockResponse {
263 pub status: u16,
264 pub headers: BTreeMap<String, String>,
265 pub body: Option<String>,
266}
267
268impl HttpMockResponse {
269 pub fn new(status: u16, headers: BTreeMap<String, String>, body: Option<String>) -> Self {
270 Self {
271 status,
272 headers,
273 body,
274 }
275 }
276}
277
278struct HttpMockRuntime {
279 mode: HttpMockMode,
280 dir: Option<PathBuf>,
281 entries: Mutex<HashMap<String, HttpMockResponse>>,
282}
283
284impl HttpMockRuntime {
285 fn new(config: &HttpMock, run_dir: &Path) -> Result<Self> {
286 let dir = config
287 .record_replay_dir
288 .clone()
289 .or_else(|| Some(run_dir.join("cassettes/http")));
290 if let Some(path) = dir.as_ref() {
291 let _ = fs::create_dir_all(path);
292 }
293 let entries = Mutex::new(Self::load_entries(dir.as_ref())?);
294 Ok(Self {
295 mode: config.mode.clone(),
296 dir,
297 entries,
298 })
299 }
300
301 fn load_entries(dir: Option<&PathBuf>) -> Result<HashMap<String, HttpMockResponse>> {
302 let mut map = HashMap::new();
303 if let Some(dir) = dir
304 && dir.exists()
305 {
306 for entry in fs::read_dir(dir)? {
307 let entry = entry?;
308 let path = entry.path();
309 if path.extension().and_then(|s| s.to_str()) != Some("json") {
310 continue;
311 }
312 let bytes = fs::read(&path)?;
313 let resp: HttpMockResponse = serde_json::from_slice(&bytes)?;
314 if let Some(stem) = path.file_stem().and_then(|s| s.to_str()) {
315 map.insert(stem.to_string(), resp);
316 }
317 }
318 }
319 Ok(map)
320 }
321
322 fn replay(&self, req: &HttpMockRequest) -> Option<HttpMockResponse> {
323 let guard = self.entries.lock();
324 guard.get(&req.fingerprint).cloned()
325 }
326
327 fn should_record(&self) -> bool {
328 matches!(self.mode, HttpMockMode::Record | HttpMockMode::RecordReplay)
329 }
330
331 fn record(&self, req: &HttpMockRequest, resp: &HttpMockResponse) -> Result<()> {
332 if !self.should_record() {
333 return Ok(());
334 }
335 let dir = match &self.dir {
336 Some(dir) => dir,
337 None => return Ok(()),
338 };
339 let path = dir.join(format!("{}.json", &req.fingerprint));
340 let mut file = fs::File::create(&path)
341 .with_context(|| format!("failed to write {}", path.display()))?;
342 let body = serde_json::to_vec_pretty(resp)?;
343 file.write_all(&body)?;
344 self.entries
345 .lock()
346 .insert(req.fingerprint.clone(), resp.clone());
347 Ok(())
348 }
349}
350
/// Makes `value` safe to embed in a file name.
///
/// Each character that is not an ASCII letter or digit becomes a single `_`; this includes
/// each multi-byte Unicode character.
fn sanitize(value: &str) -> String {
    let mut cleaned = String::with_capacity(value.len());
    for ch in value.chars() {
        let safe = if ch.is_ascii_alphanumeric() { ch } else { '_' };
        cleaned.push(safe);
    }
    cleaned
}