1use axum::{
15 extract::State,
16 response::{
17 sse::{Event, KeepAlive, Sse},
18 IntoResponse, Response,
19 },
20 routing::{get, post},
21 Json, Router,
22};
23use futures::stream::{self, Stream};
24use serde_json::{json, Value};
25use std::convert::Infallible;
26use std::time::Duration;
27
/// Settings for the mock LLM endpoints mounted by [`router`].
#[derive(Clone, Debug)]
pub struct LlmMockConfig {
    /// Fixed assistant text used as the body of every reply.
    pub canned_reply: String,
    /// Model name reported when a request omits `model`; also the only entry listed by
    /// `GET /v1/models`.
    pub default_model: String,
    /// When `true`, up to the first 120 characters of the latest user message are appended
    /// to `canned_reply` (see `build_reply`).
    pub echo_prompt: bool,
    /// Milliseconds to sleep before emitting each SSE event of a streaming response;
    /// `0` disables pacing.
    pub stream_chunk_delay_ms: u64,
}
41
42impl Default for LlmMockConfig {
43 fn default() -> Self {
44 Self {
45 canned_reply: "This is a mock response from MockForge's LLM endpoint.".to_string(),
46 default_model: "mockforge-mock-1".to_string(),
47 echo_prompt: true,
48 stream_chunk_delay_ms: 0,
49 }
50 }
51}
52
53pub fn router(config: LlmMockConfig) -> Router {
55 Router::new()
56 .route("/v1/chat/completions", post(chat_completions))
57 .route("/v1/models", get(list_models))
58 .route("/v1/messages", post(anthropic_messages))
59 .with_state(config)
60}
61
/// Rough token estimate: the number of whitespace-separated words in `text`, never less than 1.
fn approx_tokens(text: &str) -> u32 {
    let words = text.split_whitespace().count();
    // Clamp to 1 so usage counts are never zero, even for empty input.
    std::cmp::max(words, 1) as u32
}
72
73fn last_user_text(messages: &[Value]) -> String {
77 for m in messages.iter().rev() {
78 if m.get("role").and_then(|r| r.as_str()) == Some("user") {
79 return content_to_text(m.get("content"));
80 }
81 }
82 messages.last().map(|m| content_to_text(m.get("content"))).unwrap_or_default()
84}
85
86fn content_to_text(content: Option<&Value>) -> String {
87 match content {
88 Some(Value::String(s)) => s.clone(),
89 Some(Value::Array(parts)) => parts
90 .iter()
91 .filter_map(|p| p.get("text").and_then(|t| t.as_str()))
92 .collect::<Vec<_>>()
93 .join(" "),
94 _ => String::new(),
95 }
96}
97
98fn build_reply(config: &LlmMockConfig, messages: &[Value]) -> String {
100 if config.echo_prompt {
101 let prompt = last_user_text(messages);
102 if !prompt.is_empty() {
103 let trimmed: String = prompt.chars().take(120).collect();
104 return format!("{} (you said: \"{}\")", config.canned_reply, trimmed);
105 }
106 }
107 config.canned_reply.clone()
108}
109
/// Derives a deterministic id: `prefix` followed by a 16-hex-digit hash of `seed`.
///
/// The hash is FNV-1a style (xor each byte, then multiply by the 64-bit FNV prime).
/// NOTE(review): the starting value is 1469598103934665603, not the canonical 64-bit FNV
/// offset basis 14695981039346656037 (apparently a dropped trailing digit). It is kept so
/// generated ids stay unchanged; confirm nothing pins these ids before changing it.
fn stable_id(prefix: &str, seed: &str) -> String {
    const START: u64 = 1_469_598_103_934_665_603;
    const FNV_PRIME: u64 = 1_099_511_628_211;
    let digest = seed
        .bytes()
        .fold(START, |acc, byte| (acc ^ u64::from(byte)).wrapping_mul(FNV_PRIME));
    format!("{}{:016x}", prefix, digest)
}
121
/// Splits `reply` into word-sized streaming deltas whose concatenation is exactly `reply`.
///
/// Each chunk is one word, preceded by the whitespace run that separated it from the
/// previous word (e.g. `"alpha beta"` -> `["alpha", " beta"]`). Leading whitespace stays
/// with the first word, and trailing whitespace stays with the last. Whitespace is
/// preserved verbatim (newlines, tabs, repeated spaces), so a streamed reply reconstructs
/// byte-for-byte to the non-streamed one. An input with no words (empty or all
/// whitespace) yields a single chunk equal to the input.
fn stream_chunks(reply: &str) -> Vec<String> {
    let mut chunks = Vec::new();
    let mut chunk_start = 0;
    // Byte offset where the whitespace run after the current word began, while inside it.
    let mut gap_start: Option<usize> = None;
    let mut in_word = false;
    for (i, c) in reply.char_indices() {
        if c.is_whitespace() {
            if in_word {
                gap_start = Some(i);
                in_word = false;
            }
        } else {
            // A new word begins: close the previous chunk just before the separating gap,
            // so the gap travels with this word.
            if let Some(gap) = gap_start.take() {
                chunks.push(reply[chunk_start..gap].to_string());
                chunk_start = gap;
            }
            in_word = true;
        }
    }
    chunks.push(reply[chunk_start..].to_string());
    chunks
}
138
139async fn list_models(State(config): State<LlmMockConfig>) -> Json<Value> {
144 Json(json!({
145 "object": "list",
146 "data": [{
147 "id": config.default_model,
148 "object": "model",
149 "created": 0,
150 "owned_by": "mockforge",
151 }],
152 }))
153}
154
155async fn chat_completions(
160 State(config): State<LlmMockConfig>,
161 Json(body): Json<Value>,
162) -> Response {
163 let model = body
164 .get("model")
165 .and_then(|m| m.as_str())
166 .unwrap_or(&config.default_model)
167 .to_string();
168 let messages: Vec<Value> =
169 body.get("messages").and_then(|m| m.as_array()).cloned().unwrap_or_default();
170 let stream = body.get("stream").and_then(|s| s.as_bool()).unwrap_or(false);
171
172 let reply = build_reply(&config, &messages);
173 let prompt_text = messages
174 .iter()
175 .map(|m| content_to_text(m.get("content")))
176 .collect::<Vec<_>>()
177 .join(" ");
178 let prompt_tokens = approx_tokens(&prompt_text);
179 let completion_tokens = approx_tokens(&reply);
180 let id = stable_id("chatcmpl-", &reply);
181
182 if stream {
183 return openai_stream(config, id, model, reply).into_response();
184 }
185
186 Json(json!({
187 "id": id,
188 "object": "chat.completion",
189 "created": 0,
190 "model": model,
191 "choices": [{
192 "index": 0,
193 "message": { "role": "assistant", "content": reply },
194 "finish_reason": "stop",
195 }],
196 "usage": {
197 "prompt_tokens": prompt_tokens,
198 "completion_tokens": completion_tokens,
199 "total_tokens": prompt_tokens + completion_tokens,
200 },
201 }))
202 .into_response()
203}
204
205fn openai_stream(
206 config: LlmMockConfig,
207 id: String,
208 model: String,
209 reply: String,
210) -> Sse<impl Stream<Item = Result<Event, Infallible>>> {
211 let mut events: Vec<Event> = Vec::new();
212 events.push(sse_json(&json!({
214 "id": id, "object": "chat.completion.chunk", "created": 0, "model": model,
215 "choices": [{ "index": 0, "delta": { "role": "assistant" }, "finish_reason": Value::Null }],
216 })));
217 for chunk in stream_chunks(&reply) {
219 events.push(sse_json(&json!({
220 "id": id, "object": "chat.completion.chunk", "created": 0, "model": model,
221 "choices": [{ "index": 0, "delta": { "content": chunk }, "finish_reason": Value::Null }],
222 })));
223 }
224 events.push(sse_json(&json!({
226 "id": id, "object": "chat.completion.chunk", "created": 0, "model": model,
227 "choices": [{ "index": 0, "delta": {}, "finish_reason": "stop" }],
228 })));
229 events.push(Event::default().data("[DONE]"));
230
231 sse_response(events, config.stream_chunk_delay_ms)
232}
233
234async fn anthropic_messages(
239 State(config): State<LlmMockConfig>,
240 Json(body): Json<Value>,
241) -> Response {
242 let model = body
243 .get("model")
244 .and_then(|m| m.as_str())
245 .unwrap_or(&config.default_model)
246 .to_string();
247 let messages: Vec<Value> =
248 body.get("messages").and_then(|m| m.as_array()).cloned().unwrap_or_default();
249 let stream = body.get("stream").and_then(|s| s.as_bool()).unwrap_or(false);
250
251 let reply = build_reply(&config, &messages);
252 let prompt_text = messages
253 .iter()
254 .map(|m| content_to_text(m.get("content")))
255 .collect::<Vec<_>>()
256 .join(" ");
257 let input_tokens = approx_tokens(&prompt_text);
258 let output_tokens = approx_tokens(&reply);
259 let id = stable_id("msg_", &reply);
260
261 if stream {
262 return anthropic_stream(config, id, model, reply, input_tokens, output_tokens)
263 .into_response();
264 }
265
266 Json(json!({
267 "id": id,
268 "type": "message",
269 "role": "assistant",
270 "model": model,
271 "content": [{ "type": "text", "text": reply }],
272 "stop_reason": "end_turn",
273 "stop_sequence": Value::Null,
274 "usage": { "input_tokens": input_tokens, "output_tokens": output_tokens },
275 }))
276 .into_response()
277}
278
279fn anthropic_stream(
280 config: LlmMockConfig,
281 id: String,
282 model: String,
283 reply: String,
284 input_tokens: u32,
285 output_tokens: u32,
286) -> Sse<impl Stream<Item = Result<Event, Infallible>>> {
287 let mut events: Vec<Event> = Vec::new();
288 events.push(sse_named(
289 "message_start",
290 &json!({
291 "type": "message_start",
292 "message": {
293 "id": id, "type": "message", "role": "assistant", "model": model,
294 "content": [], "stop_reason": Value::Null, "stop_sequence": Value::Null,
295 "usage": { "input_tokens": input_tokens, "output_tokens": 0 },
296 },
297 }),
298 ));
299 events.push(sse_named(
300 "content_block_start",
301 &json!({ "type": "content_block_start", "index": 0, "content_block": { "type": "text", "text": "" } }),
302 ));
303 for chunk in stream_chunks(&reply) {
304 events.push(sse_named(
305 "content_block_delta",
306 &json!({ "type": "content_block_delta", "index": 0, "delta": { "type": "text_delta", "text": chunk } }),
307 ));
308 }
309 events.push(sse_named(
310 "content_block_stop",
311 &json!({ "type": "content_block_stop", "index": 0 }),
312 ));
313 events.push(sse_named(
314 "message_delta",
315 &json!({ "type": "message_delta", "delta": { "stop_reason": "end_turn", "stop_sequence": Value::Null }, "usage": { "output_tokens": output_tokens } }),
316 ));
317 events.push(sse_named("message_stop", &json!({ "type": "message_stop" })));
318
319 sse_response(events, config.stream_chunk_delay_ms)
320}
321
322fn sse_json(value: &Value) -> Event {
327 Event::default().data(value.to_string())
328}
329
330fn sse_named(name: &str, value: &Value) -> Event {
331 Event::default().event(name).data(value.to_string())
332}
333
334fn sse_response(
337 events: Vec<Event>,
338 delay_ms: u64,
339) -> Sse<impl Stream<Item = Result<Event, Infallible>>> {
340 let s = stream::unfold(events.into_iter(), move |mut it| async move {
341 let next = it.next()?;
342 if delay_ms > 0 {
343 tokio::time::sleep(Duration::from_millis(delay_ms)).await;
344 }
345 Some((Ok::<Event, Infallible>(next), it))
346 });
347 Sse::new(s).keep_alive(KeepAlive::new().interval(Duration::from_secs(15)))
348}
349
// Unit tests for the pure helpers and for the non-streaming response shapes of each handler.
#[cfg(test)]
mod tests {
    use super::*;

    // Default config with prompt echo turned off, so replies are exactly the canned text.
    fn cfg() -> LlmMockConfig {
        LlmMockConfig {
            echo_prompt: false,
            ..Default::default()
        }
    }

    #[test]
    fn approx_tokens_counts_words() {
        assert_eq!(approx_tokens("one two three"), 3);
        // Empty input is clamped up to one token.
        assert_eq!(approx_tokens(""), 1);
    }

    #[test]
    fn last_user_text_handles_string_and_array_content() {
        // The system message is skipped in favour of the last user message.
        let msgs = vec![
            json!({"role":"system","content":"be brief"}),
            json!({"role":"user","content":"hello world"}),
        ];
        assert_eq!(last_user_text(&msgs), "hello world");
        // Array-of-parts content has its text parts joined with single spaces.
        let arr = vec![
            json!({"role":"user","content":[{"type":"text","text":"a"},{"type":"text","text":"b"}]}),
        ];
        assert_eq!(last_user_text(&arr), "a b");
    }

    #[test]
    fn echo_prompt_reflects_user_message() {
        let c = LlmMockConfig {
            echo_prompt: true,
            ..Default::default()
        };
        let msgs = vec![json!({"role":"user","content":"ping"})];
        let reply = build_reply(&c, &msgs);
        assert!(reply.contains("ping"), "reply should echo the prompt: {reply}");
    }

    #[test]
    fn stable_id_is_deterministic_and_prefixed() {
        let a = stable_id("chatcmpl-", "same");
        let b = stable_id("chatcmpl-", "same");
        assert_eq!(a, b);
        assert!(a.starts_with("chatcmpl-"));
        assert_ne!(stable_id("chatcmpl-", "x"), stable_id("chatcmpl-", "y"));
    }

    #[test]
    fn stream_chunks_preserve_leading_space_after_first() {
        // Concatenating the chunks must reproduce the original reply.
        let chunks = stream_chunks("alpha beta gamma");
        assert_eq!(chunks, vec!["alpha", " beta", " gamma"]);
        assert_eq!(chunks.concat(), "alpha beta gamma");
    }

    // Calls the handler directly (no HTTP server) and decodes the JSON body it returns.
    #[tokio::test]
    async fn chat_completions_non_stream_shape() {
        let body = json!({"model":"gpt-x","messages":[{"role":"user","content":"hi there"}]});
        let resp = chat_completions(State(cfg()), Json(body)).await;
        let bytes = axum::body::to_bytes(resp.into_body(), usize::MAX).await.unwrap();
        let v: Value = serde_json::from_slice(&bytes).unwrap();
        assert_eq!(v["object"], "chat.completion");
        assert_eq!(v["choices"][0]["message"]["role"], "assistant");
        assert_eq!(v["choices"][0]["finish_reason"], "stop");
        assert!(v["usage"]["total_tokens"].as_u64().unwrap() >= 2);
        assert!(v["id"].as_str().unwrap().starts_with("chatcmpl-"));
    }

    #[tokio::test]
    async fn anthropic_non_stream_shape() {
        let body = json!({"model":"claude-x","messages":[{"role":"user","content":"hi"}]});
        let resp = anthropic_messages(State(cfg()), Json(body)).await;
        let bytes = axum::body::to_bytes(resp.into_body(), usize::MAX).await.unwrap();
        let v: Value = serde_json::from_slice(&bytes).unwrap();
        assert_eq!(v["type"], "message");
        assert_eq!(v["content"][0]["type"], "text");
        assert_eq!(v["stop_reason"], "end_turn");
        assert!(v["usage"]["output_tokens"].as_u64().unwrap() >= 1);
        assert!(v["id"].as_str().unwrap().starts_with("msg_"));
    }

    #[tokio::test]
    async fn models_list_shape() {
        let Json(v) = list_models(State(cfg())).await;
        assert_eq!(v["object"], "list");
        assert_eq!(v["data"][0]["owned_by"], "mockforge");
    }
}