1use crate::error::MockError;
2use mockito;
3use potato_type::anthropic::AnthropicMessageResponse;
4use potato_type::google::v1::generate::{DataNum, GenerateContentResponse};
5use potato_type::google::GeminiEmbeddingResponse;
6use potato_type::openai::v1::embedding::OpenAIEmbeddingResponse;
7use potato_type::openai::v1::OpenAIChatResponse;
8use potato_type::openai::{ChatMessage, ContentPart, TextContentPart};
9use potato_type::prompt::{MessageNum, Prompt, Role, Score};
10use potato_type::StructuredOutput;
11use pyo3::prelude::*;
12use rand::Rng;
13use serde_json;
14
/// Raw text of `assets/openai/embedding_response.json`, embedded at compile time.
pub const OPENAI_EMBEDDING_RESPONSE: &str = include_str!("assets/openai/embedding_response.json");

/// Raw text of `assets/gemini/embedding_response.json`, embedded at compile time.
pub const GEMINI_EMBEDDING_RESPONSE: &str = include_str!("assets/gemini/embedding_response.json");

/// Raw text of `assets/openai/openai_chat_completion_response.json`, embedded at compile time.
pub const OPENAI_CHAT_COMPLETION_RESPONSE: &str =
    include_str!("assets/openai/openai_chat_completion_response.json");

/// Raw text of `assets/openai/chat_completion_structured_response.json`, embedded at compile time.
pub const OPENAI_CHAT_STRUCTURED_RESPONSE: &str =
    include_str!("assets/openai/chat_completion_structured_response.json");

/// Raw text of `assets/openai/chat_completion_structured_score_response.json`, embedded at
/// compile time.
pub const OPENAI_CHAT_STRUCTURED_SCORE_RESPONSE: &str =
    include_str!("assets/openai/chat_completion_structured_score_response.json");

/// Raw text of `assets/openai/chat_completion_structured_response_params.json`, embedded at
/// compile time.
pub const OPENAI_CHAT_STRUCTURED_RESPONSE_PARAMS: &str =
    include_str!("assets/openai/chat_completion_structured_response_params.json");

/// Raw text of `assets/openai/chat_completion_structured_task_output.json`, embedded at
/// compile time.
pub const OPENAI_CHAT_STRUCTURED_TASK_OUTPUT: &str =
    include_str!("assets/openai/chat_completion_structured_task_output.json");

/// Raw text of `assets/gemini/chat_completion.json`, embedded at compile time.
pub const GEMINI_CHAT_COMPLETION_RESPONSE: &str =
    include_str!("assets/gemini/chat_completion.json");

/// Raw text of `assets/gemini/chat_completion_with_score.json`, embedded at compile time.
pub const GEMINI_CHAT_COMPLETION_RESPONSE_WITH_SCORE: &str =
    include_str!("assets/gemini/chat_completion_with_score.json");

/// Raw text of `assets/anthropic/message_completion.json`, embedded at compile time.
pub const ANTHROPIC_MESSAGE_RESPONSE: &str =
    include_str!("assets/anthropic/message_completion.json");

/// Raw text of `assets/anthropic/message_structured_completion.json`, embedded at compile time.
pub const ANTHROPIC_MESSAGE_STRUCTURED_RESPONSE: &str =
    include_str!("assets/anthropic/message_structured_completion.json");

/// Raw text of `assets/anthropic/message_structured_completion_tasks.json`, embedded at
/// compile time.
pub const ANTHROPIC_MESSAGE_STRUCTURED_TASK_OUTPUT: &str =
    include_str!("assets/anthropic/message_structured_completion_tasks.json");
48
49fn randomize_openai_embedding_response(
50 response: OpenAIEmbeddingResponse,
51) -> OpenAIEmbeddingResponse {
52 let mut cloned_response = response.clone();
54 let mut rng = rand::rng();
55 let embedding: Vec<f32> = (0..512).map(|_| rng.random_range(-1.0..1.0)).collect();
56 cloned_response.data[0].embedding = embedding;
57 cloned_response
58}
59
60fn randomize_gemini_embedding_response(
61 response: GeminiEmbeddingResponse,
62) -> GeminiEmbeddingResponse {
63 let mut cloned_response = response.clone();
64 let mut rng = rand::rng();
65 let embedding: Vec<f32> = (0..512).map(|_| rng.random_range(-1.0..1.0)).collect();
66 cloned_response.embedding.values = embedding;
67 cloned_response
68}
69
70fn randomize_structured_openai_score_response(response: &OpenAIChatResponse) -> OpenAIChatResponse {
71 let mut cloned_response = response.clone();
72 let mut rng = rand::rng();
73
74 let score = rng.random_range(1..=5);
76
77 let reasons = [
79 "The code is excellent and follows best practices.",
80 "The implementation is solid with minor improvements possible.",
81 "The code works but could use some optimization.",
82 "The solution is functional but needs refactoring.",
83 "The code has significant issues that need addressing.",
84 ];
85 let reason = reasons[rng.random_range(0..reasons.len())];
86
87 cloned_response.choices[0].message.content = Some(format!(
88 "{{ \"score\": {}, \"reason\": \"{}\" }}",
89 score, reason
90 ));
91
92 cloned_response
93}
94
95fn randomize_gemini_score_response(response: GenerateContentResponse) -> GenerateContentResponse {
96 let mut cloned_response = response.clone();
97 let mut rng = rand::rng();
98
99 let score = rng.random_range(1..=5);
101
102 let reasons = [
104 "The model performed exceptionally well on the evaluation.",
105 "Good performance with room for minor improvements.",
106 "Satisfactory results with some areas for optimization.",
107 "Adequate performance but needs significant improvements.",
108 "Performance below expectations, requires major adjustments.",
109 ];
110 let reason = reasons[rng.random_range(0..reasons.len())];
111
112 if let Some(candidate) = cloned_response.candidates.get_mut(0) {
114 if let Some(part) = candidate.content.parts.get_mut(0) {
115 part.data = DataNum::Text(format!(
116 "{{\"score\": {}, \"reason\": \"{}\"}}",
117 score, reason
118 ));
119 }
120 }
121
122 cloned_response
123}
124
/// A mockito server pre-loaded with canned LLM provider responses.
pub struct LLMApiMock {
    /// Base URL of the mock server, taken from `server.url()` at construction.
    pub url: String,
    /// Guard for the underlying mockito server that the mocks are registered on.
    pub server: mockito::ServerGuard,
}
129
130impl LLMApiMock {
131 pub fn new() -> Self {
132 let mut server = mockito::Server::new();
133 let openai_embedding_response: OpenAIEmbeddingResponse =
135 serde_json::from_str(OPENAI_EMBEDDING_RESPONSE).unwrap();
136 let chat_msg_response: OpenAIChatResponse =
137 serde_json::from_str(OPENAI_CHAT_COMPLETION_RESPONSE).unwrap();
138 let chat_structured_response: OpenAIChatResponse =
139 serde_json::from_str(OPENAI_CHAT_STRUCTURED_RESPONSE).unwrap();
140 let chat_structured_score_response: OpenAIChatResponse =
141 serde_json::from_str(OPENAI_CHAT_STRUCTURED_SCORE_RESPONSE).unwrap();
142 let chat_structured_response_params: OpenAIChatResponse =
143 serde_json::from_str(OPENAI_CHAT_STRUCTURED_RESPONSE_PARAMS).unwrap();
144 let chat_structured_task_output: OpenAIChatResponse =
145 serde_json::from_str(OPENAI_CHAT_STRUCTURED_TASK_OUTPUT).unwrap();
146
147 let gemini_chat_response: GenerateContentResponse =
149 serde_json::from_str(GEMINI_CHAT_COMPLETION_RESPONSE).unwrap();
150 let gemini_chat_response_with_score: GenerateContentResponse =
151 serde_json::from_str(GEMINI_CHAT_COMPLETION_RESPONSE_WITH_SCORE).unwrap();
152 let gemini_embedding_response: GeminiEmbeddingResponse =
153 serde_json::from_str(GEMINI_EMBEDDING_RESPONSE).unwrap();
154
155 let anthropic_message_response: AnthropicMessageResponse =
157 serde_json::from_str(ANTHROPIC_MESSAGE_RESPONSE).unwrap();
158
159 let anthropic_message_structured_response: AnthropicMessageResponse =
160 serde_json::from_str(ANTHROPIC_MESSAGE_STRUCTURED_RESPONSE).unwrap();
161
162 let anthropic_message_structured_task_output: AnthropicMessageResponse =
163 serde_json::from_str(ANTHROPIC_MESSAGE_STRUCTURED_TASK_OUTPUT).unwrap();
164
165 server
166 .mock("POST", "/chat/completions")
167 .match_body(mockito::Matcher::PartialJson(serde_json::json!({
168 "response_format": {
169 "type": "json_schema",
170 "json_schema": {
171 "name": "Parameters",
172 "schema": {
173 "$schema": "https://json-schema.org/draft/2020-12/schema",
174 "properties": {
175 "variable1": {
176 "format": "int32",
177 "type": "integer"
178 },
179 "variable2": {
180 "format": "int32",
181 "type": "integer"
182 }
183 },
184 "required": [
185 "variable1",
186 "variable2"
187 ],
188 "title": "Parameters",
189 "type": "object"
190 },
191 "strict": true
192 }
193
194 }
195 })))
196 .expect(usize::MAX)
197 .with_status(200)
198 .with_header("content-type", "application/json")
199 .with_body(serde_json::to_string(&chat_structured_response_params).unwrap())
200 .create();
201
202 server
203 .mock("POST", "/chat/completions")
204 .match_body(mockito::Matcher::PartialJson(serde_json::json!({
205 "response_format": {
206 "type": "json_schema",
207 "json_schema": {
208 "name": "TaskOutput",
209 }
210 }
211 })))
212 .expect(usize::MAX)
213 .with_status(200)
214 .with_header("content-type", "application/json")
215 .with_body(serde_json::to_string(&chat_structured_task_output).unwrap())
216 .create();
217
218 server
219 .mock("POST", "/chat/completions")
220 .match_body(mockito::Matcher::PartialJson(serde_json::json!({
221 "response_format": {
222 "type": "json_schema",
223 "json_schema": {
224 "name": "Score",
225 }
226 }
227 })))
228 .expect(usize::MAX)
229 .with_status(200)
230 .with_header("content-type", "application/json")
231 .with_body(serde_json::to_string(&chat_structured_score_response).unwrap())
232 .with_body_from_request({
233 let chat_structured_score_response = chat_structured_score_response.clone();
234 move |_request| {
235 let randomized_response = randomize_structured_openai_score_response(
236 &chat_structured_score_response.clone(),
237 );
238 serde_json::to_string(&randomized_response).unwrap().into()
239 }
240 })
241 .create();
242
243 server
244 .mock("POST", "/chat/completions")
245 .match_body(mockito::Matcher::Regex(
246 r#".*"name"\s*:\s*"Score".*"#.to_string(),
247 ))
248 .expect(usize::MAX)
249 .with_status(200)
250 .with_header("content-type", "application/json")
251 .with_body_from_request({
252 let chat_structured_score_response = chat_structured_score_response.clone();
253 move |_request| {
254 let randomized_response = randomize_structured_openai_score_response(
255 &chat_structured_score_response.clone(),
256 );
257 serde_json::to_string(&randomized_response).unwrap().into()
258 }
259 })
260 .create();
261
262 server
263 .mock("POST", "/chat/completions")
264 .match_body(mockito::Matcher::PartialJson(serde_json::json!({
265 "response_format": {
266 "type": "json_schema"
267 }
268 })))
269 .expect(usize::MAX)
270 .with_status(200)
271 .with_header("content-type", "application/json")
272 .with_body(serde_json::to_string(&chat_structured_response).unwrap())
273 .create();
274
275 server
277 .mock(
278 "POST",
279 mockito::Matcher::Regex(r".*/.*:generateContent$".to_string()),
280 )
281 .match_header("x-goog-api-key", mockito::Matcher::Any)
282 .match_header("content-type", "application/json")
283 .match_body(mockito::Matcher::PartialJson(serde_json::json!({
284 "contents": [
285 {
286 "parts": [
287 {
288 "text": "You are a helpful assistant"
289 }
290 ]
291 }
292 ]
293 })))
294 .expect(usize::MAX) .with_status(200)
296 .with_header("content-type", "application/json")
297 .with_body(serde_json::to_string(&gemini_chat_response).unwrap())
298 .create();
299
300 server
302 .mock(
303 "POST",
304 mockito::Matcher::Regex(r".*/.*:generateContent$".to_string()),
305 )
306 .match_header("x-goog-api-key", mockito::Matcher::Any)
307 .match_header("content-type", "application/json")
308 .match_body(mockito::Matcher::PartialJson(serde_json::json!({
309 "generation_config": {
310 "responseMimeType": "application/json"
311 }
312 })))
313 .expect(usize::MAX)
314 .with_status(200)
315 .with_header("content-type", "application/json")
316 .with_body_from_request(move |_request| {
317 let randomized_response =
318 randomize_gemini_score_response(gemini_chat_response_with_score.clone());
319 serde_json::to_string(&randomized_response).unwrap().into()
320 })
321 .create();
322
323 server
325 .mock("POST", "/chat/completions")
326 .expect(usize::MAX)
327 .with_status(200)
328 .with_header("content-type", "application/json")
329 .with_body(serde_json::to_string(&chat_msg_response).unwrap())
330 .create();
331
332 server
333 .mock("POST", "/embeddings")
334 .expect(usize::MAX)
335 .with_status(200)
336 .with_header("content-type", "application/json")
337 .with_body_from_request(move |_request| {
338 let randomized_response =
339 randomize_openai_embedding_response(openai_embedding_response.clone());
340 serde_json::to_string(&randomized_response).unwrap().into()
341 })
342 .create();
343
344 server
345 .mock(
346 "POST",
347 mockito::Matcher::Regex(r".*/.*:embedContent$".to_string()),
348 )
349 .expect(usize::MAX)
350 .with_status(200)
351 .with_header("content-type", "application/json")
352 .with_body_from_request(move |_request| {
353 let randomized_response =
354 randomize_gemini_embedding_response(gemini_embedding_response.clone());
355 serde_json::to_string(&randomized_response).unwrap().into()
356 })
357 .create();
358
359 server
362 .mock("POST", "/messages")
363 .match_header("content-type", "application/json")
364 .match_body(mockito::Matcher::PartialJson(serde_json::json!({
365 "messages": [
366 {
367 "content": [
368 {
369 "text": "Give me a score!",
370 "type": "text"
371 }
372 ]
373 }
374 ]
375 })))
376 .expect(usize::MAX) .with_status(200)
378 .with_header("content-type", "application/json")
379 .with_body(serde_json::to_string(&anthropic_message_structured_response).unwrap())
380 .create();
381
382 server
383 .mock("POST", "/messages")
384 .match_header("content-type", "application/json")
385 .match_body(mockito::Matcher::Regex(
386 r#".*"text"\s*:\s*"Give me a task list!".*"#.to_string(),
387 ))
388 .expect(usize::MAX) .with_status(200)
390 .with_header("content-type", "application/json")
391 .with_body(serde_json::to_string(&anthropic_message_structured_task_output).unwrap())
392 .create();
393
394 server
395 .mock("POST", "/messages")
396 .expect(usize::MAX)
397 .with_status(200)
398 .with_header("content-type", "application/json")
399 .with_body(serde_json::to_string(&anthropic_message_response).unwrap())
400 .create();
401
402 Self {
403 url: server.url(),
404 server,
405 }
406 }
407}
408
409impl Default for LLMApiMock {
410 fn default() -> Self {
411 Self::new()
412 }
413}
414
/// Python-exposed wrapper that owns an optional [`LLMApiMock`] and publishes
/// its URL.
#[pyclass]
// NOTE(review): it is not visible from here which item the dead-code lint
// would flag; confirm this allow is still needed.
#[allow(dead_code)]
pub struct LLMTestServer {
    /// The running mock, or `None` when no mock has been started.
    openai_server: Option<LLMApiMock>,

    /// URL of the mock server, readable from Python; `None` until
    /// `start_server` has run.
    #[pyo3(get)]
    pub url: Option<String>,
}
423
424#[pymethods]
425impl LLMTestServer {
426 #[new]
427 pub fn new() -> Self {
428 LLMTestServer {
429 openai_server: None,
430 url: None,
431 }
432 }
433
434 pub fn start_mock_server(&mut self) -> Result<(), MockError> {
435 let llm_server = LLMApiMock::new();
436 println!("Mock LLM Server started at {}", llm_server.url);
437 self.openai_server = Some(llm_server);
438 Ok(())
439 }
440
441 pub fn stop_mock_server(&mut self) {
442 if let Some(server) = self.openai_server.take() {
443 drop(server);
444 std::env::remove_var("OPENAI_API_URL");
445 std::env::remove_var("OPENAI_API_KEY");
446 std::env::remove_var("GEMINI_API_KEY");
447 std::env::remove_var("GEMINI_API_URL");
448 std::env::remove_var("GOOGLE_API_KEY");
449 std::env::remove_var("GOOGLE_API_URL");
450 std::env::remove_var("ANTHROPIC_API_KEY");
451 std::env::remove_var("ANTHROPIC_API_URL");
452 }
453 println!("Mock LLM Server stopped");
454 }
455
456 pub fn set_env_vars_for_client(&self) -> Result<(), MockError> {
457 {
458 std::env::set_var("APP_ENV", "dev_client");
459 std::env::set_var("OPENAI_API_KEY", "test_key");
460 std::env::set_var("GEMINI_API_KEY", "gemini");
461 std::env::set_var("GOOGLE_API_KEY", "google");
462 std::env::set_var("ANTHROPIC_API_KEY", "anthropic_key");
463 std::env::set_var(
464 "OPENAI_API_URL",
465 self.openai_server.as_ref().unwrap().url.clone(),
466 );
467 std::env::set_var(
468 "GEMINI_API_URL",
469 self.openai_server.as_ref().unwrap().url.clone(),
470 );
471 std::env::set_var(
472 "GOOGLE_API_URL",
473 self.openai_server.as_ref().unwrap().url.clone(),
474 );
475 std::env::set_var(
476 "ANTHROPIC_API_URL",
477 self.openai_server.as_ref().unwrap().url.clone(),
478 );
479
480 Ok(())
481 }
482 }
483
484 pub fn start_server(&mut self) -> Result<(), MockError> {
485 self.cleanup()?;
486
487 println!("Starting Mock GenAI Server...");
488 self.start_mock_server()?;
489 self.set_env_vars_for_client()?;
490
491 std::env::set_var("APP_ENV", "dev_server");
493
494 self.url = Some(self.openai_server.as_ref().unwrap().url.clone());
495
496 Ok(())
497 }
498
499 pub fn stop_server(&mut self) -> Result<(), MockError> {
500 self.cleanup()?;
501
502 Ok(())
503 }
504
505 pub fn remove_env_vars_for_client(&self) -> Result<(), MockError> {
506 std::env::remove_var("OPENAI_API_URI");
507 std::env::remove_var("OPENAI_API_KEY");
508 std::env::remove_var("GEMINI_API_KEY");
509 std::env::remove_var("GEMINI_API_URL");
510 std::env::remove_var("GOOGLE_API_KEY");
511 std::env::remove_var("GOOGLE_API_URL");
512 std::env::remove_var("ANTHROPIC_API_KEY");
513 std::env::remove_var("ANTHROPIC_API_URL");
514 Ok(())
515 }
516
517 fn cleanup(&self) -> Result<(), MockError> {
518 self.remove_env_vars_for_client()?;
520
521 Ok(())
522 }
523
524 fn __enter__(mut self_: PyRefMut<Self>) -> Result<PyRefMut<Self>, MockError> {
525 self_.start_server()?;
526
527 Ok(self_)
528 }
529
530 fn __exit__(
531 &mut self,
532 _exc_type: Py<PyAny>,
533 _exc_value: Py<PyAny>,
534 _traceback: Py<PyAny>,
535 ) -> Result<(), MockError> {
536 self.stop_server()
537 }
538}
539
540impl Default for LLMTestServer {
541 fn default() -> Self {
542 Self::new()
543 }
544}
545
546#[allow(clippy::uninlined_format_args)]
547pub fn create_score_prompt(params: Option<Vec<String>>) -> Prompt {
548 let mut user_prompt = "What is the score?".to_string();
549
550 if let Some(params) = params {
552 for param in params {
553 user_prompt.push_str(&format!(" ${{{}}}", param));
554 }
555 }
556
557 let system_content = "You are a helpful assistant.".to_string();
558
559 let system_msg = ChatMessage {
560 role: Role::Developer.to_string(),
561 content: vec![ContentPart::Text(TextContentPart::new(system_content))],
562 name: None,
563 };
564
565 let user_msg = ChatMessage {
566 role: Role::User.to_string(),
567 content: vec![ContentPart::Text(TextContentPart::new(user_prompt))],
568 name: None,
569 };
570 Prompt::new_rs(
571 vec![MessageNum::OpenAIMessageV1(user_msg)],
572 "gpt-4o",
573 potato_type::Provider::OpenAI,
574 vec![MessageNum::OpenAIMessageV1(system_msg)],
575 None,
576 Some(Score::get_structured_output_schema()),
577 potato_type::prompt::ResponseType::Score,
578 )
579 .unwrap()
580}
581
582pub fn create_parameterized_prompt() -> Prompt {
583 let user_content = "What is ${variable1} + ${variable2}?".to_string();
584 let system_content = "You are a helpful assistant.".to_string();
585
586 let system_msg = ChatMessage {
587 role: Role::Developer.to_string(),
588 content: vec![ContentPart::Text(TextContentPart::new(system_content))],
589 name: None,
590 };
591
592 let user_msg = ChatMessage {
593 role: Role::User.to_string(),
594 content: vec![ContentPart::Text(TextContentPart::new(user_content))],
595 name: None,
596 };
597 Prompt::new_rs(
598 vec![MessageNum::OpenAIMessageV1(user_msg)],
599 "gpt-4o",
600 potato_type::Provider::OpenAI,
601 vec![MessageNum::OpenAIMessageV1(system_msg)],
602 None,
603 None,
604 potato_type::prompt::ResponseType::Null,
605 )
606 .unwrap()
607}