//! Mock LLM API servers and prompt fixtures for tests (baked_potato/mock.rs).
1use crate::error::MockError;
2use mockito;
3use potato_type::anthropic::AnthropicMessageResponse;
4use potato_type::google::v1::generate::{DataNum, GenerateContentResponse};
5use potato_type::google::GeminiEmbeddingResponse;
6use potato_type::openai::v1::embedding::OpenAIEmbeddingResponse;
7use potato_type::openai::v1::OpenAIChatResponse;
8use potato_type::openai::{ChatMessage, ContentPart, TextContentPart};
9use potato_type::prompt::{MessageNum, Prompt, Role, Score};
10use potato_type::StructuredOutput;
11use pyo3::prelude::*;
12use rand::Rng;
13use serde_json;
14
// Canned JSON response fixtures, embedded at compile time from `assets/`.
// Each constant is the raw body a mocked provider endpoint serves (some are
// randomized per request before being returned — see the `randomize_*` fns).

// OpenAI embedding response (vector is randomized per request).
pub const OPENAI_EMBEDDING_RESPONSE: &str = include_str!("assets/openai/embedding_response.json");

// Gemini embedding response (vector is randomized per request).
pub const GEMINI_EMBEDDING_RESPONSE: &str = include_str!("assets/gemini/embedding_response.json");

// OpenAI plain chat completion (catch-all for /chat/completions).
pub const OPENAI_CHAT_COMPLETION_RESPONSE: &str =
    include_str!("assets/openai/openai_chat_completion_response.json");

// OpenAI structured-output completion (generic json_schema fallback).
pub const OPENAI_CHAT_STRUCTURED_RESPONSE: &str =
    include_str!("assets/openai/chat_completion_structured_response.json");

// OpenAI structured "Score" completion (score/reason randomized per request).
pub const OPENAI_CHAT_STRUCTURED_SCORE_RESPONSE: &str =
    include_str!("assets/openai/chat_completion_structured_score_response.json");

// OpenAI structured completion for the "Parameters" schema.
pub const OPENAI_CHAT_STRUCTURED_RESPONSE_PARAMS: &str =
    include_str!("assets/openai/chat_completion_structured_response_params.json");

// OpenAI structured completion for the "TaskOutput" schema.
pub const OPENAI_CHAT_STRUCTURED_TASK_OUTPUT: &str =
    include_str!("assets/openai/chat_completion_structured_task_output.json");

// Gemini plain chat completion.
pub const GEMINI_CHAT_COMPLETION_RESPONSE: &str =
    include_str!("assets/gemini/chat_completion.json");

// Gemini structured completion with a score (randomized per request).
pub const GEMINI_CHAT_COMPLETION_RESPONSE_WITH_SCORE: &str =
    include_str!("assets/gemini/chat_completion_with_score.json");

// Anthropic plain message completion (catch-all for /messages).
pub const ANTHROPIC_MESSAGE_RESPONSE: &str =
    include_str!("assets/anthropic/message_completion.json");

// Anthropic structured message completion (score request).
pub const ANTHROPIC_MESSAGE_STRUCTURED_RESPONSE: &str =
    include_str!("assets/anthropic/message_structured_completion.json");

// Anthropic structured message completion (task-list request).
pub const ANTHROPIC_MESSAGE_STRUCTURED_TASK_OUTPUT: &str =
    include_str!("assets/anthropic/message_structured_completion_tasks.json");
48
49fn randomize_openai_embedding_response(
50    response: OpenAIEmbeddingResponse,
51) -> OpenAIEmbeddingResponse {
52    // create random Vec<f32> of length 512
53    let mut cloned_response = response.clone();
54    let mut rng = rand::rng();
55    let embedding: Vec<f32> = (0..512).map(|_| rng.random_range(-1.0..1.0)).collect();
56    cloned_response.data[0].embedding = embedding;
57    cloned_response
58}
59
60fn randomize_gemini_embedding_response(
61    response: GeminiEmbeddingResponse,
62) -> GeminiEmbeddingResponse {
63    let mut cloned_response = response.clone();
64    let mut rng = rand::rng();
65    let embedding: Vec<f32> = (0..512).map(|_| rng.random_range(-1.0..1.0)).collect();
66    cloned_response.embedding.values = embedding;
67    cloned_response
68}
69
70fn randomize_structured_openai_score_response(response: &OpenAIChatResponse) -> OpenAIChatResponse {
71    let mut cloned_response = response.clone();
72    let mut rng = rand::rng();
73
74    // Generate random score between 1 and 5
75    let score = rng.random_range(1..=5);
76
77    // Generate random reason from a set of predefined reasons
78    let reasons = [
79        "The code is excellent and follows best practices.",
80        "The implementation is solid with minor improvements possible.",
81        "The code works but could use some optimization.",
82        "The solution is functional but needs refactoring.",
83        "The code has significant issues that need addressing.",
84    ];
85    let reason = reasons[rng.random_range(0..reasons.len())];
86
87    cloned_response.choices[0].message.content = Some(format!(
88        "{{ \"score\": {}, \"reason\": \"{}\" }}",
89        score, reason
90    ));
91
92    cloned_response
93}
94
95fn randomize_gemini_score_response(response: GenerateContentResponse) -> GenerateContentResponse {
96    let mut cloned_response = response.clone();
97    let mut rng = rand::rng();
98
99    // Generate random score between 1 and 5 (typical for Gemini scoring)
100    let score = rng.random_range(1..=5);
101
102    // Generate random reason from a set of predefined reasons
103    let reasons = [
104        "The model performed exceptionally well on the evaluation.",
105        "Good performance with room for minor improvements.",
106        "Satisfactory results with some areas for optimization.",
107        "Adequate performance but needs significant improvements.",
108        "Performance below expectations, requires major adjustments.",
109    ];
110    let reason = reasons[rng.random_range(0..reasons.len())];
111
112    // Update the first candidate's content
113    if let Some(candidate) = cloned_response.candidates.get_mut(0) {
114        if let Some(part) = candidate.content.parts.get_mut(0) {
115            part.data = DataNum::Text(format!(
116                "{{\"score\": {}, \"reason\": \"{}\"}}",
117                score, reason
118            ));
119        }
120    }
121
122    cloned_response
123}
124
/// Handle to a running mockito HTTP server that emulates the OpenAI,
/// Gemini, and Anthropic HTTP APIs for tests.
pub struct LLMApiMock {
    // Base URL of the mock server (e.g. "http://127.0.0.1:<port>").
    pub url: String,
    // Guard that keeps the server alive; dropping it shuts the server down.
    pub server: mockito::ServerGuard,
}
129
130impl LLMApiMock {
131    pub fn new() -> Self {
132        let mut server = mockito::Server::new();
133        // load the OpenAI chat completion response
134        let openai_embedding_response: OpenAIEmbeddingResponse =
135            serde_json::from_str(OPENAI_EMBEDDING_RESPONSE).unwrap();
136        let chat_msg_response: OpenAIChatResponse =
137            serde_json::from_str(OPENAI_CHAT_COMPLETION_RESPONSE).unwrap();
138        let chat_structured_response: OpenAIChatResponse =
139            serde_json::from_str(OPENAI_CHAT_STRUCTURED_RESPONSE).unwrap();
140        let chat_structured_score_response: OpenAIChatResponse =
141            serde_json::from_str(OPENAI_CHAT_STRUCTURED_SCORE_RESPONSE).unwrap();
142        let chat_structured_response_params: OpenAIChatResponse =
143            serde_json::from_str(OPENAI_CHAT_STRUCTURED_RESPONSE_PARAMS).unwrap();
144        let chat_structured_task_output: OpenAIChatResponse =
145            serde_json::from_str(OPENAI_CHAT_STRUCTURED_TASK_OUTPUT).unwrap();
146
147        // load the Gemini chat completion response
148        let gemini_chat_response: GenerateContentResponse =
149            serde_json::from_str(GEMINI_CHAT_COMPLETION_RESPONSE).unwrap();
150        let gemini_chat_response_with_score: GenerateContentResponse =
151            serde_json::from_str(GEMINI_CHAT_COMPLETION_RESPONSE_WITH_SCORE).unwrap();
152        let gemini_embedding_response: GeminiEmbeddingResponse =
153            serde_json::from_str(GEMINI_EMBEDDING_RESPONSE).unwrap();
154
155        // anthropic message response
156        let anthropic_message_response: AnthropicMessageResponse =
157            serde_json::from_str(ANTHROPIC_MESSAGE_RESPONSE).unwrap();
158
159        let anthropic_message_structured_response: AnthropicMessageResponse =
160            serde_json::from_str(ANTHROPIC_MESSAGE_STRUCTURED_RESPONSE).unwrap();
161
162        let anthropic_message_structured_task_output: AnthropicMessageResponse =
163            serde_json::from_str(ANTHROPIC_MESSAGE_STRUCTURED_TASK_OUTPUT).unwrap();
164
165        server
166            .mock("POST", "/chat/completions")
167            .match_body(mockito::Matcher::PartialJson(serde_json::json!({
168                "response_format": {
169                    "type": "json_schema",
170                    "json_schema": {
171                        "name": "Parameters",
172                         "schema": {
173                              "$schema": "https://json-schema.org/draft/2020-12/schema",
174                              "properties": {
175                                  "variable1": {
176                                  "format": "int32",
177                                  "type": "integer"
178                                  },
179                                  "variable2": {
180                                  "format": "int32",
181                                  "type": "integer"
182                                  }
183                              },
184                              "required": [
185                                  "variable1",
186                                  "variable2"
187                              ],
188                              "title": "Parameters",
189                              "type": "object"
190                              },
191                        "strict": true
192                    }
193
194                }
195            })))
196            .expect(usize::MAX)
197            .with_status(200)
198            .with_header("content-type", "application/json")
199            .with_body(serde_json::to_string(&chat_structured_response_params).unwrap())
200            .create();
201
202        server
203            .mock("POST", "/chat/completions")
204            .match_body(mockito::Matcher::PartialJson(serde_json::json!({
205               "response_format": {
206                    "type": "json_schema",
207                    "json_schema": {
208                        "name": "TaskOutput",
209                    }
210                }
211            })))
212            .expect(usize::MAX)
213            .with_status(200)
214            .with_header("content-type", "application/json")
215            .with_body(serde_json::to_string(&chat_structured_task_output).unwrap())
216            .create();
217
218        server
219            .mock("POST", "/chat/completions")
220            .match_body(mockito::Matcher::PartialJson(serde_json::json!({
221                "response_format": {
222                    "type": "json_schema",
223                    "json_schema": {
224                        "name": "Score",
225                    }
226                }
227            })))
228            .expect(usize::MAX)
229            .with_status(200)
230            .with_header("content-type", "application/json")
231            .with_body(serde_json::to_string(&chat_structured_score_response).unwrap())
232            .with_body_from_request({
233                let chat_structured_score_response = chat_structured_score_response.clone();
234                move |_request| {
235                    let randomized_response = randomize_structured_openai_score_response(
236                        &chat_structured_score_response.clone(),
237                    );
238                    serde_json::to_string(&randomized_response).unwrap().into()
239                }
240            })
241            .create();
242
243        server
244            .mock("POST", "/chat/completions")
245            .match_body(mockito::Matcher::Regex(
246                r#".*"name"\s*:\s*"Score".*"#.to_string(),
247            ))
248            .expect(usize::MAX)
249            .with_status(200)
250            .with_header("content-type", "application/json")
251            .with_body_from_request({
252                let chat_structured_score_response = chat_structured_score_response.clone();
253                move |_request| {
254                    let randomized_response = randomize_structured_openai_score_response(
255                        &chat_structured_score_response.clone(),
256                    );
257                    serde_json::to_string(&randomized_response).unwrap().into()
258                }
259            })
260            .create();
261
262        server
263            .mock("POST", "/chat/completions")
264            .match_body(mockito::Matcher::PartialJson(serde_json::json!({
265                "response_format": {
266                    "type": "json_schema"
267                }
268            })))
269            .expect(usize::MAX)
270            .with_status(200)
271            .with_header("content-type", "application/json")
272            .with_body(serde_json::to_string(&chat_structured_response).unwrap())
273            .create();
274
275        // mock the Gemini chat completion response
276        server
277            .mock(
278                "POST",
279                mockito::Matcher::Regex(r".*/.*:generateContent$".to_string()),
280            )
281            .match_header("x-goog-api-key", mockito::Matcher::Any)
282            .match_header("content-type", "application/json")
283            .match_body(mockito::Matcher::PartialJson(serde_json::json!({
284                "contents": [
285                    {
286                        "parts": [
287                            {
288                                "text":  "You are a helpful assistant"
289                            }
290                        ]
291                    }
292                ]
293            })))
294            .expect(usize::MAX) // More specific expectation than usize::MAX
295            .with_status(200)
296            .with_header("content-type", "application/json")
297            .with_body(serde_json::to_string(&gemini_chat_response).unwrap())
298            .create();
299
300        // mock structured response
301        server
302            .mock(
303                "POST",
304                mockito::Matcher::Regex(r".*/.*:generateContent$".to_string()),
305            )
306            .match_header("x-goog-api-key", mockito::Matcher::Any)
307            .match_header("content-type", "application/json")
308            .match_body(mockito::Matcher::PartialJson(serde_json::json!({
309                "generation_config": {
310                    "responseMimeType": "application/json"
311                }
312            })))
313            .expect(usize::MAX)
314            .with_status(200)
315            .with_header("content-type", "application/json")
316            .with_body_from_request(move |_request| {
317                let randomized_response =
318                    randomize_gemini_score_response(gemini_chat_response_with_score.clone());
319                serde_json::to_string(&randomized_response).unwrap().into()
320            })
321            .create();
322
323        // Openai chat completion mock
324        server
325            .mock("POST", "/chat/completions")
326            .expect(usize::MAX)
327            .with_status(200)
328            .with_header("content-type", "application/json")
329            .with_body(serde_json::to_string(&chat_msg_response).unwrap())
330            .create();
331
332        server
333            .mock("POST", "/embeddings")
334            .expect(usize::MAX)
335            .with_status(200)
336            .with_header("content-type", "application/json")
337            .with_body_from_request(move |_request| {
338                let randomized_response =
339                    randomize_openai_embedding_response(openai_embedding_response.clone());
340                serde_json::to_string(&randomized_response).unwrap().into()
341            })
342            .create();
343
344        server
345            .mock(
346                "POST",
347                mockito::Matcher::Regex(r".*/.*:embedContent$".to_string()),
348            )
349            .expect(usize::MAX)
350            .with_status(200)
351            .with_header("content-type", "application/json")
352            .with_body_from_request(move |_request| {
353                let randomized_response =
354                    randomize_gemini_embedding_response(gemini_embedding_response.clone());
355                serde_json::to_string(&randomized_response).unwrap().into()
356            })
357            .create();
358
359        // mock the anthropic message response
360
361        server
362            .mock("POST", "/messages")
363            .match_header("content-type", "application/json")
364            .match_body(mockito::Matcher::PartialJson(serde_json::json!({
365                "messages": [
366                    {
367                        "content": [
368                            {
369                                "text":  "Give me a score!",
370                                "type": "text"
371                            }
372                        ]
373                    }
374                ]
375            })))
376            .expect(usize::MAX) // More specific expectation than usize::MAX
377            .with_status(200)
378            .with_header("content-type", "application/json")
379            .with_body(serde_json::to_string(&anthropic_message_structured_response).unwrap())
380            .create();
381
382        server
383            .mock("POST", "/messages")
384            .match_header("content-type", "application/json")
385            .match_body(mockito::Matcher::Regex(
386                r#".*"text"\s*:\s*"Give me a task list!".*"#.to_string(),
387            ))
388            .expect(usize::MAX) // More specific expectation than usize::MAX
389            .with_status(200)
390            .with_header("content-type", "application/json")
391            .with_body(serde_json::to_string(&anthropic_message_structured_task_output).unwrap())
392            .create();
393
394        server
395            .mock("POST", "/messages")
396            .expect(usize::MAX)
397            .with_status(200)
398            .with_header("content-type", "application/json")
399            .with_body(serde_json::to_string(&anthropic_message_response).unwrap())
400            .create();
401
402        Self {
403            url: server.url(),
404            server,
405        }
406    }
407}
408
409impl Default for LLMApiMock {
410    fn default() -> Self {
411        Self::new()
412    }
413}
414
/// Python-facing (pyo3) wrapper around [`LLMApiMock`], usable as a context
/// manager from Python test code via `__enter__`/`__exit__`.
#[pyclass]
#[allow(dead_code)]
pub struct LLMTestServer {
    // The running mock server, if started. Named "openai" historically,
    // but it serves all providers (OpenAI, Gemini, Anthropic).
    openai_server: Option<LLMApiMock>,

    // Base URL of the mock server once started; exposed read-only to Python.
    #[pyo3(get)]
    pub url: Option<String>,
}
423
424#[pymethods]
425impl LLMTestServer {
426    #[new]
427    pub fn new() -> Self {
428        LLMTestServer {
429            openai_server: None,
430            url: None,
431        }
432    }
433
434    pub fn start_mock_server(&mut self) -> Result<(), MockError> {
435        let llm_server = LLMApiMock::new();
436        println!("Mock LLM Server started at {}", llm_server.url);
437        self.openai_server = Some(llm_server);
438        Ok(())
439    }
440
441    pub fn stop_mock_server(&mut self) {
442        if let Some(server) = self.openai_server.take() {
443            drop(server);
444            std::env::remove_var("OPENAI_API_URL");
445            std::env::remove_var("OPENAI_API_KEY");
446            std::env::remove_var("GEMINI_API_KEY");
447            std::env::remove_var("GEMINI_API_URL");
448            std::env::remove_var("GOOGLE_API_KEY");
449            std::env::remove_var("GOOGLE_API_URL");
450            std::env::remove_var("ANTHROPIC_API_KEY");
451            std::env::remove_var("ANTHROPIC_API_URL");
452        }
453        println!("Mock LLM Server stopped");
454    }
455
456    pub fn set_env_vars_for_client(&self) -> Result<(), MockError> {
457        {
458            std::env::set_var("APP_ENV", "dev_client");
459            std::env::set_var("OPENAI_API_KEY", "test_key");
460            std::env::set_var("GEMINI_API_KEY", "gemini");
461            std::env::set_var("GOOGLE_API_KEY", "google");
462            std::env::set_var("ANTHROPIC_API_KEY", "anthropic_key");
463            std::env::set_var(
464                "OPENAI_API_URL",
465                self.openai_server.as_ref().unwrap().url.clone(),
466            );
467            std::env::set_var(
468                "GEMINI_API_URL",
469                self.openai_server.as_ref().unwrap().url.clone(),
470            );
471            std::env::set_var(
472                "GOOGLE_API_URL",
473                self.openai_server.as_ref().unwrap().url.clone(),
474            );
475            std::env::set_var(
476                "ANTHROPIC_API_URL",
477                self.openai_server.as_ref().unwrap().url.clone(),
478            );
479
480            Ok(())
481        }
482    }
483
484    pub fn start_server(&mut self) -> Result<(), MockError> {
485        self.cleanup()?;
486
487        println!("Starting Mock GenAI Server...");
488        self.start_mock_server()?;
489        self.set_env_vars_for_client()?;
490
491        // set server env vars
492        std::env::set_var("APP_ENV", "dev_server");
493
494        self.url = Some(self.openai_server.as_ref().unwrap().url.clone());
495
496        Ok(())
497    }
498
499    pub fn stop_server(&mut self) -> Result<(), MockError> {
500        self.cleanup()?;
501
502        Ok(())
503    }
504
505    pub fn remove_env_vars_for_client(&self) -> Result<(), MockError> {
506        std::env::remove_var("OPENAI_API_URI");
507        std::env::remove_var("OPENAI_API_KEY");
508        std::env::remove_var("GEMINI_API_KEY");
509        std::env::remove_var("GEMINI_API_URL");
510        std::env::remove_var("GOOGLE_API_KEY");
511        std::env::remove_var("GOOGLE_API_URL");
512        std::env::remove_var("ANTHROPIC_API_KEY");
513        std::env::remove_var("ANTHROPIC_API_URL");
514        Ok(())
515    }
516
517    fn cleanup(&self) -> Result<(), MockError> {
518        // unset env vars
519        self.remove_env_vars_for_client()?;
520
521        Ok(())
522    }
523
524    fn __enter__(mut self_: PyRefMut<Self>) -> Result<PyRefMut<Self>, MockError> {
525        self_.start_server()?;
526
527        Ok(self_)
528    }
529
530    fn __exit__(
531        &mut self,
532        _exc_type: Py<PyAny>,
533        _exc_value: Py<PyAny>,
534        _traceback: Py<PyAny>,
535    ) -> Result<(), MockError> {
536        self.stop_server()
537    }
538}
539
540impl Default for LLMTestServer {
541    fn default() -> Self {
542        Self::new()
543    }
544}
545
546#[allow(clippy::uninlined_format_args)]
547pub fn create_score_prompt(params: Option<Vec<String>>) -> Prompt {
548    let mut user_prompt = "What is the score?".to_string();
549
550    // If parameters are provided, append them to the user prompt in format ${param}
551    if let Some(params) = params {
552        for param in params {
553            user_prompt.push_str(&format!(" ${{{}}}", param));
554        }
555    }
556
557    let system_content = "You are a helpful assistant.".to_string();
558
559    let system_msg = ChatMessage {
560        role: Role::Developer.to_string(),
561        content: vec![ContentPart::Text(TextContentPart::new(system_content))],
562        name: None,
563    };
564
565    let user_msg = ChatMessage {
566        role: Role::User.to_string(),
567        content: vec![ContentPart::Text(TextContentPart::new(user_prompt))],
568        name: None,
569    };
570    Prompt::new_rs(
571        vec![MessageNum::OpenAIMessageV1(user_msg)],
572        "gpt-4o",
573        potato_type::Provider::OpenAI,
574        vec![MessageNum::OpenAIMessageV1(system_msg)],
575        None,
576        Some(Score::get_structured_output_schema()),
577        potato_type::prompt::ResponseType::Score,
578    )
579    .unwrap()
580}
581
582pub fn create_parameterized_prompt() -> Prompt {
583    let user_content = "What is ${variable1} + ${variable2}?".to_string();
584    let system_content = "You are a helpful assistant.".to_string();
585
586    let system_msg = ChatMessage {
587        role: Role::Developer.to_string(),
588        content: vec![ContentPart::Text(TextContentPart::new(system_content))],
589        name: None,
590    };
591
592    let user_msg = ChatMessage {
593        role: Role::User.to_string(),
594        content: vec![ContentPart::Text(TextContentPart::new(user_content))],
595        name: None,
596    };
597    Prompt::new_rs(
598        vec![MessageNum::OpenAIMessageV1(user_msg)],
599        "gpt-4o",
600        potato_type::Provider::OpenAI,
601        vec![MessageNum::OpenAIMessageV1(system_msg)],
602        None,
603        None,
604        potato_type::prompt::ResponseType::Null,
605    )
606    .unwrap()
607}