baked_potato/
mock.rs

1use crate::error::MockError;
2use mockito;
3use potato_provider::{GenerateContentResponse, OpenAIChatResponse};
4use potato_type::google::GeminiEmbeddingResponse;
5use potato_type::openai::embedding::OpenAIEmbeddingResponse;
6use pyo3::prelude::*;
7use rand::Rng;
8use serde_json;
9
10pub const OPENAI_EMBEDDING_RESPONSE: &str = include_str!("assets/openai/embedding_response.json");
11
12pub const GEMINI_EMBEDDING_RESPONSE: &str = include_str!("assets/gemini/embedding_response.json");
13
14pub const OPENAI_CHAT_COMPLETION_RESPONSE: &str =
15    include_str!("assets/openai/openai_chat_completion_response.json");
16
17pub const OPENAI_CHAT_STRUCTURED_RESPONSE: &str =
18    include_str!("assets/openai/chat_completion_structured_response.json");
19
20pub const OPENAI_CHAT_STRUCTURED_SCORE_RESPONSE: &str =
21    include_str!("assets/openai/chat_completion_structured_score_response.json");
22
23pub const OPENAI_CHAT_STRUCTURED_RESPONSE_PARAMS: &str =
24    include_str!("assets/openai/chat_completion_structured_response_params.json");
25
26pub const OPENAI_CHAT_STRUCTURED_TASK_OUTPUT: &str =
27    include_str!("assets/openai/chat_completion_structured_task_output.json");
28
29pub const GEMINI_CHAT_COMPLETION_RESPONSE: &str =
30    include_str!("assets/gemini/chat_completion.json");
31
32pub const GEMINI_CHAT_COMPLETION_RESPONSE_WITH_SCORE: &str =
33    include_str!("assets/gemini/chat_completion_with_score.json");
34
35fn randomize_openai_embedding_response(
36    response: OpenAIEmbeddingResponse,
37) -> OpenAIEmbeddingResponse {
38    // create random Vec<f32> of length 512
39    let mut cloned_response = response.clone();
40    let mut rng = rand::rng();
41    let embedding: Vec<f32> = (0..512).map(|_| rng.random_range(-1.0..1.0)).collect();
42    cloned_response.data[0].embedding = embedding;
43    cloned_response
44}
45
46fn randomize_gemini_embedding_response(
47    response: GeminiEmbeddingResponse,
48) -> GeminiEmbeddingResponse {
49    let mut cloned_response = response.clone();
50    let mut rng = rand::rng();
51    let embedding: Vec<f32> = (0..512).map(|_| rng.random_range(-1.0..1.0)).collect();
52    cloned_response.embedding.values = embedding;
53    cloned_response
54}
55
56fn randomize_structured_openai_score_response(response: &OpenAIChatResponse) -> OpenAIChatResponse {
57    let mut cloned_response = response.clone();
58    let mut rng = rand::rng();
59
60    // Generate random score between 1 and 5
61    let score = rng.random_range(1..=5);
62
63    // Generate random reason from a set of predefined reasons
64    let reasons = [
65        "The code is excellent and follows best practices.",
66        "The implementation is solid with minor improvements possible.",
67        "The code works but could use some optimization.",
68        "The solution is functional but needs refactoring.",
69        "The code has significant issues that need addressing.",
70    ];
71    let reason = reasons[rng.random_range(0..reasons.len())];
72
73    cloned_response.choices[0].message.content = Some(format!(
74        "{{ \"score\": {}, \"reason\": \"{}\" }}",
75        score, reason
76    ));
77
78    cloned_response
79}
80
81fn randomize_gemini_score_response(response: GenerateContentResponse) -> GenerateContentResponse {
82    let mut cloned_response = response.clone();
83    let mut rng = rand::rng();
84
85    // Generate random score between 1 and 100 (typical for Gemini scoring)
86    let score = rng.random_range(1..=100);
87
88    // Generate random reason from a set of predefined reasons
89    let reasons = [
90        "The model performed exceptionally well on the evaluation.",
91        "Good performance with room for minor improvements.",
92        "Satisfactory results with some areas for optimization.",
93        "Adequate performance but needs significant improvements.",
94        "Performance below expectations, requires major adjustments.",
95    ];
96    let reason = reasons[rng.random_range(0..reasons.len())];
97
98    // Update the first candidate's content
99    if let Some(candidate) = cloned_response.candidates.get_mut(0) {
100        if let Some(part) = candidate.content.parts.get_mut(0) {
101            part.text = Some(format!(
102                "{{\"score\": {}, \"reason\": \"{}\"}}",
103                score, reason
104            ));
105        }
106    }
107
108    cloned_response
109}
110
111pub struct LLMApiMock {
112    pub url: String,
113    pub server: mockito::ServerGuard,
114}
115
116impl LLMApiMock {
117    pub fn new() -> Self {
118        let mut server = mockito::Server::new();
119        // load the OpenAI chat completion response
120        let openai_embedding_response: OpenAIEmbeddingResponse =
121            serde_json::from_str(OPENAI_EMBEDDING_RESPONSE).unwrap();
122        let chat_msg_response: OpenAIChatResponse =
123            serde_json::from_str(OPENAI_CHAT_COMPLETION_RESPONSE).unwrap();
124        let chat_structured_response: OpenAIChatResponse =
125            serde_json::from_str(OPENAI_CHAT_STRUCTURED_RESPONSE).unwrap();
126        let chat_structured_score_response: OpenAIChatResponse =
127            serde_json::from_str(OPENAI_CHAT_STRUCTURED_SCORE_RESPONSE).unwrap();
128        let chat_structured_response_params: OpenAIChatResponse =
129            serde_json::from_str(OPENAI_CHAT_STRUCTURED_RESPONSE_PARAMS).unwrap();
130        let chat_structured_task_output: OpenAIChatResponse =
131            serde_json::from_str(OPENAI_CHAT_STRUCTURED_TASK_OUTPUT).unwrap();
132
133        // load the Gemini chat completion response
134        let gemini_chat_response: GenerateContentResponse =
135            serde_json::from_str(GEMINI_CHAT_COMPLETION_RESPONSE).unwrap();
136        let gemini_chat_response_with_score: GenerateContentResponse =
137            serde_json::from_str(GEMINI_CHAT_COMPLETION_RESPONSE_WITH_SCORE).unwrap();
138        let gemini_embedding_response: GeminiEmbeddingResponse =
139            serde_json::from_str(GEMINI_EMBEDDING_RESPONSE).unwrap();
140
141        server
142            .mock("POST", "/chat/completions")
143            .match_body(mockito::Matcher::PartialJson(serde_json::json!({
144                "response_format": {
145                    "type": "json_schema",
146                    "json_schema": {
147                        "name": "Parameters",
148                         "schema": {
149                              "$schema": "https://json-schema.org/draft/2020-12/schema",
150                              "properties": {
151                                  "variable1": {
152                                  "format": "int32",
153                                  "type": "integer"
154                                  },
155                                  "variable2": {
156                                  "format": "int32",
157                                  "type": "integer"
158                                  }
159                              },
160                              "required": [
161                                  "variable1",
162                                  "variable2"
163                              ],
164                              "title": "Parameters",
165                              "type": "object"
166                              },
167                        "strict": true
168                    }
169
170                }
171            })))
172            .expect(usize::MAX)
173            .with_status(200)
174            .with_header("content-type", "application/json")
175            .with_body(serde_json::to_string(&chat_structured_response_params).unwrap())
176            .create();
177
178        server
179            .mock("POST", "/chat/completions")
180            .match_body(mockito::Matcher::PartialJson(serde_json::json!({
181               "response_format": {
182                    "type": "json_schema",
183                    "json_schema": {
184                        "name": "TaskOutput",
185                    }
186                }
187            })))
188            .expect(usize::MAX)
189            .with_status(200)
190            .with_header("content-type", "application/json")
191            .with_body(serde_json::to_string(&chat_structured_task_output).unwrap())
192            .create();
193
194        server
195            .mock("POST", "/chat/completions")
196            .match_body(mockito::Matcher::PartialJson(serde_json::json!({
197                "response_format": {
198                    "type": "json_schema",
199                    "json_schema": {
200                        "name": "Score",
201                    }
202                }
203            })))
204            .expect(usize::MAX)
205            .with_status(200)
206            .with_header("content-type", "application/json")
207            .with_body(serde_json::to_string(&chat_structured_score_response).unwrap())
208            .with_body_from_request({
209                let chat_structured_score_response = chat_structured_score_response.clone();
210                move |_request| {
211                    let randomized_response = randomize_structured_openai_score_response(
212                        &chat_structured_score_response.clone(),
213                    );
214                    serde_json::to_string(&randomized_response).unwrap().into()
215                }
216            })
217            .create();
218
219        server
220            .mock("POST", "/chat/completions")
221            .match_body(mockito::Matcher::Regex(
222                r#".*"name"\s*:\s*"Score".*"#.to_string(),
223            ))
224            .expect(usize::MAX)
225            .with_status(200)
226            .with_header("content-type", "application/json")
227            .with_body_from_request({
228                let chat_structured_score_response = chat_structured_score_response.clone();
229                move |_request| {
230                    let randomized_response = randomize_structured_openai_score_response(
231                        &chat_structured_score_response.clone(),
232                    );
233                    serde_json::to_string(&randomized_response).unwrap().into()
234                }
235            })
236            .create();
237
238        server
239            .mock("POST", "/chat/completions")
240            .match_body(mockito::Matcher::PartialJson(serde_json::json!({
241                "response_format": {
242                    "type": "json_schema"
243                }
244            })))
245            .expect(usize::MAX)
246            .with_status(200)
247            .with_header("content-type", "application/json")
248            .with_body(serde_json::to_string(&chat_structured_response).unwrap())
249            .create();
250
251        // mock the Gemini chat completion response
252        server
253            .mock(
254                "POST",
255                mockito::Matcher::Regex(r".*/.*:generateContent$".to_string()),
256            )
257            .match_header("x-goog-api-key", mockito::Matcher::Any)
258            .match_header("content-type", "application/json")
259            .match_body(mockito::Matcher::PartialJson(serde_json::json!({
260                "contents": [
261                    {
262                        "parts": [
263                            {
264                                "text":  "You are a helpful assistant"
265                            }
266                        ]
267                    }
268                ]
269            })))
270            .expect(usize::MAX) // More specific expectation than usize::MAX
271            .with_status(200)
272            .with_header("content-type", "application/json")
273            .with_body(serde_json::to_string(&gemini_chat_response).unwrap())
274            .create();
275
276        // mock structured response
277        server
278            .mock(
279                "POST",
280                mockito::Matcher::Regex(r".*/.*:generateContent$".to_string()),
281            )
282            .match_header("x-goog-api-key", mockito::Matcher::Any)
283            .match_header("content-type", "application/json")
284            .match_body(mockito::Matcher::PartialJson(serde_json::json!({
285                "generation_config": {
286                    "responseMimeType": "application/json"
287                }
288            })))
289            .expect(usize::MAX)
290            .with_status(200)
291            .with_header("content-type", "application/json")
292            .with_body_from_request(move |_request| {
293                let randomized_response =
294                    randomize_gemini_score_response(gemini_chat_response_with_score.clone());
295                serde_json::to_string(&randomized_response).unwrap().into()
296            })
297            .create();
298
299        // Openai chat completion mock
300        server
301            .mock("POST", "/chat/completions")
302            .expect(usize::MAX)
303            .with_status(200)
304            .with_header("content-type", "application/json")
305            .with_body(serde_json::to_string(&chat_msg_response).unwrap())
306            .create();
307
308        server
309            .mock("POST", "/embeddings")
310            .expect(usize::MAX)
311            .with_status(200)
312            .with_header("content-type", "application/json")
313            .with_body_from_request(move |_request| {
314                let randomized_response =
315                    randomize_openai_embedding_response(openai_embedding_response.clone());
316                serde_json::to_string(&randomized_response).unwrap().into()
317            })
318            .create();
319
320        server
321            .mock(
322                "POST",
323                mockito::Matcher::Regex(r".*/.*:embedContent$".to_string()),
324            )
325            .expect(usize::MAX)
326            .with_status(200)
327            .with_header("content-type", "application/json")
328            .with_body_from_request(move |_request| {
329                let randomized_response =
330                    randomize_gemini_embedding_response(gemini_embedding_response.clone());
331                serde_json::to_string(&randomized_response).unwrap().into()
332            })
333            .create();
334
335        Self {
336            url: server.url(),
337            server,
338        }
339    }
340}
341
342impl Default for LLMApiMock {
343    fn default() -> Self {
344        Self::new()
345    }
346}
347
348#[pyclass]
349#[allow(dead_code)]
350pub struct LLMTestServer {
351    openai_server: Option<LLMApiMock>,
352
353    #[pyo3(get)]
354    pub url: Option<String>,
355}
356
357#[pymethods]
358impl LLMTestServer {
359    #[new]
360    pub fn new() -> Self {
361        LLMTestServer {
362            openai_server: None,
363            url: None,
364        }
365    }
366
367    pub fn start_mock_server(&mut self) -> Result<(), MockError> {
368        let llm_server = LLMApiMock::new();
369        println!("Mock LLM Server started at {}", llm_server.url);
370        self.openai_server = Some(llm_server);
371        Ok(())
372    }
373
374    pub fn stop_mock_server(&mut self) {
375        if let Some(server) = self.openai_server.take() {
376            drop(server);
377            std::env::remove_var("OPENAI_API_URL");
378            std::env::remove_var("OPENAI_API_KEY");
379        }
380        println!("Mock LLM Server stopped");
381    }
382
383    pub fn set_env_vars_for_client(&self) -> Result<(), MockError> {
384        {
385            std::env::set_var("APP_ENV", "dev_client");
386            std::env::set_var("OPENAI_API_KEY", "test_key");
387            std::env::set_var("GEMINI_API_KEY", "gemini");
388            std::env::set_var(
389                "OPENAI_API_URL",
390                self.openai_server.as_ref().unwrap().url.clone(),
391            );
392            std::env::set_var(
393                "GEMINI_API_URL",
394                self.openai_server.as_ref().unwrap().url.clone(),
395            );
396
397            Ok(())
398        }
399    }
400
401    pub fn start_server(&mut self) -> Result<(), MockError> {
402        self.cleanup()?;
403
404        println!("Starting Mock GenAI Server...");
405        self.start_mock_server()?;
406        self.set_env_vars_for_client()?;
407
408        // set server env vars
409        std::env::set_var("APP_ENV", "dev_server");
410
411        self.url = Some(self.openai_server.as_ref().unwrap().url.clone());
412
413        Ok(())
414    }
415
416    pub fn stop_server(&mut self) -> Result<(), MockError> {
417        self.cleanup()?;
418
419        Ok(())
420    }
421
422    pub fn remove_env_vars_for_client(&self) -> Result<(), MockError> {
423        std::env::remove_var("OPENAI_API_URI");
424        std::env::remove_var("OPENAI_API_KEY");
425        std::env::remove_var("GEMINI_API_KEY");
426        std::env::remove_var("GEMINI_API_URL");
427        Ok(())
428    }
429
430    fn cleanup(&self) -> Result<(), MockError> {
431        // unset env vars
432        self.remove_env_vars_for_client()?;
433
434        Ok(())
435    }
436
437    fn __enter__(mut self_: PyRefMut<Self>) -> Result<PyRefMut<Self>, MockError> {
438        self_.start_server()?;
439
440        Ok(self_)
441    }
442
443    fn __exit__(
444        &mut self,
445        _exc_type: Py<PyAny>,
446        _exc_value: Py<PyAny>,
447        _traceback: Py<PyAny>,
448    ) -> Result<(), MockError> {
449        self.stop_server()
450    }
451}
452
453impl Default for LLMTestServer {
454    fn default() -> Self {
455        Self::new()
456    }
457}