1use crate::error::MockError;
2use mockito;
3use potato_provider::{GenerateContentResponse, OpenAIChatResponse};
4use potato_type::google::GeminiEmbeddingResponse;
5use potato_type::openai::embedding::OpenAIEmbeddingResponse;
6use pyo3::prelude::*;
7use rand::Rng;
8use serde_json;
9
/// Canned OpenAI `/embeddings` response, loaded from a bundled JSON asset.
pub const OPENAI_EMBEDDING_RESPONSE: &str = include_str!("assets/openai/embedding_response.json");

/// Canned Gemini `:embedContent` response fixture.
pub const GEMINI_EMBEDDING_RESPONSE: &str = include_str!("assets/gemini/embedding_response.json");

/// Canned OpenAI `/chat/completions` plain chat response.
pub const OPENAI_CHAT_COMPLETION_RESPONSE: &str =
    include_str!("assets/openai/openai_chat_completion_response.json");

/// Structured-output (`json_schema`) chat response fixture.
pub const OPENAI_CHAT_STRUCTURED_RESPONSE: &str =
    include_str!("assets/openai/chat_completion_structured_response.json");

/// Structured response fixture for the `Score` schema; the body is later
/// randomized per request by the mock server.
pub const OPENAI_CHAT_STRUCTURED_SCORE_RESPONSE: &str =
    include_str!("assets/openai/chat_completion_structured_score_response.json");

/// Structured response fixture for the `Parameters` schema.
pub const OPENAI_CHAT_STRUCTURED_RESPONSE_PARAMS: &str =
    include_str!("assets/openai/chat_completion_structured_response_params.json");

/// Structured response fixture for the `TaskOutput` schema.
pub const OPENAI_CHAT_STRUCTURED_TASK_OUTPUT: &str =
    include_str!("assets/openai/chat_completion_structured_task_output.json");

/// Canned Gemini `:generateContent` chat response.
pub const GEMINI_CHAT_COMPLETION_RESPONSE: &str =
    include_str!("assets/gemini/chat_completion.json");

/// Gemini chat response whose part text carries a JSON score payload;
/// randomized per request by the mock server.
pub const GEMINI_CHAT_COMPLETION_RESPONSE_WITH_SCORE: &str =
    include_str!("assets/gemini/chat_completion_with_score.json");
34
35fn randomize_openai_embedding_response(
36 response: OpenAIEmbeddingResponse,
37) -> OpenAIEmbeddingResponse {
38 let mut cloned_response = response.clone();
40 let mut rng = rand::rng();
41 let embedding: Vec<f32> = (0..512).map(|_| rng.random_range(-1.0..1.0)).collect();
42 cloned_response.data[0].embedding = embedding;
43 cloned_response
44}
45
46fn randomize_gemini_embedding_response(
47 response: GeminiEmbeddingResponse,
48) -> GeminiEmbeddingResponse {
49 let mut cloned_response = response.clone();
50 let mut rng = rand::rng();
51 let embedding: Vec<f32> = (0..512).map(|_| rng.random_range(-1.0..1.0)).collect();
52 cloned_response.embedding.values = embedding;
53 cloned_response
54}
55
56fn randomize_structured_openai_score_response(response: &OpenAIChatResponse) -> OpenAIChatResponse {
57 let mut cloned_response = response.clone();
58 let mut rng = rand::rng();
59
60 let score = rng.random_range(1..=5);
62
63 let reasons = [
65 "The code is excellent and follows best practices.",
66 "The implementation is solid with minor improvements possible.",
67 "The code works but could use some optimization.",
68 "The solution is functional but needs refactoring.",
69 "The code has significant issues that need addressing.",
70 ];
71 let reason = reasons[rng.random_range(0..reasons.len())];
72
73 cloned_response.choices[0].message.content = Some(format!(
74 "{{ \"score\": {}, \"reason\": \"{}\" }}",
75 score, reason
76 ));
77
78 cloned_response
79}
80
81fn randomize_gemini_score_response(response: GenerateContentResponse) -> GenerateContentResponse {
82 let mut cloned_response = response.clone();
83 let mut rng = rand::rng();
84
85 let score = rng.random_range(1..=100);
87
88 let reasons = [
90 "The model performed exceptionally well on the evaluation.",
91 "Good performance with room for minor improvements.",
92 "Satisfactory results with some areas for optimization.",
93 "Adequate performance but needs significant improvements.",
94 "Performance below expectations, requires major adjustments.",
95 ];
96 let reason = reasons[rng.random_range(0..reasons.len())];
97
98 if let Some(candidate) = cloned_response.candidates.get_mut(0) {
100 if let Some(part) = candidate.content.parts.get_mut(0) {
101 part.text = Some(format!(
102 "{{\"score\": {}, \"reason\": \"{}\"}}",
103 score, reason
104 ));
105 }
106 }
107
108 cloned_response
109}
110
/// A running `mockito` server preconfigured with OpenAI-style and Gemini-style
/// endpoints. The server stays up for as long as this struct is alive; dropping
/// it shuts the mock server down.
pub struct LLMApiMock {
    // Base URL of the mock server, suitable for pointing clients at.
    pub url: String,
    // Guard that keeps the mockito server alive.
    pub server: mockito::ServerGuard,
}
115
impl LLMApiMock {
    /// Starts a mockito server and registers mocks for the OpenAI chat,
    /// OpenAI embeddings, Gemini chat, and Gemini embeddings endpoints.
    ///
    /// More specific body matchers are registered before the generic
    /// catch-all `/chat/completions` mock. NOTE(review): which mock wins when
    /// several match depends on mockito's matching-order semantics — the
    /// registration order here looks deliberate; confirm against the mockito
    /// version in use before reordering.
    pub fn new() -> Self {
        let mut server = mockito::Server::new();
        // Deserialize every bundled fixture up front; unwrap is acceptable
        // here because the assets are compiled in and a parse failure is a
        // test-setup bug, not a runtime condition.
        let openai_embedding_response: OpenAIEmbeddingResponse =
            serde_json::from_str(OPENAI_EMBEDDING_RESPONSE).unwrap();
        let chat_msg_response: OpenAIChatResponse =
            serde_json::from_str(OPENAI_CHAT_COMPLETION_RESPONSE).unwrap();
        let chat_structured_response: OpenAIChatResponse =
            serde_json::from_str(OPENAI_CHAT_STRUCTURED_RESPONSE).unwrap();
        let chat_structured_score_response: OpenAIChatResponse =
            serde_json::from_str(OPENAI_CHAT_STRUCTURED_SCORE_RESPONSE).unwrap();
        let chat_structured_response_params: OpenAIChatResponse =
            serde_json::from_str(OPENAI_CHAT_STRUCTURED_RESPONSE_PARAMS).unwrap();
        let chat_structured_task_output: OpenAIChatResponse =
            serde_json::from_str(OPENAI_CHAT_STRUCTURED_TASK_OUTPUT).unwrap();

        let gemini_chat_response: GenerateContentResponse =
            serde_json::from_str(GEMINI_CHAT_COMPLETION_RESPONSE).unwrap();
        let gemini_chat_response_with_score: GenerateContentResponse =
            serde_json::from_str(GEMINI_CHAT_COMPLETION_RESPONSE_WITH_SCORE).unwrap();
        let gemini_embedding_response: GeminiEmbeddingResponse =
            serde_json::from_str(GEMINI_EMBEDDING_RESPONSE).unwrap();

        // Structured output: full `Parameters` schema match -> params fixture.
        server
            .mock("POST", "/chat/completions")
            .match_body(mockito::Matcher::PartialJson(serde_json::json!({
                "response_format": {
                    "type": "json_schema",
                    "json_schema": {
                        "name": "Parameters",
                        "schema": {
                            "$schema": "https://json-schema.org/draft/2020-12/schema",
                            "properties": {
                                "variable1": {
                                    "format": "int32",
                                    "type": "integer"
                                },
                                "variable2": {
                                    "format": "int32",
                                    "type": "integer"
                                }
                            },
                            "required": [
                                "variable1",
                                "variable2"
                            ],
                            "title": "Parameters",
                            "type": "object"
                        },
                        "strict": true
                    }

                }
            })))
            .expect(usize::MAX)
            .with_status(200)
            .with_header("content-type", "application/json")
            .with_body(serde_json::to_string(&chat_structured_response_params).unwrap())
            .create();

        // Structured output: `TaskOutput` schema name -> task-output fixture.
        server
            .mock("POST", "/chat/completions")
            .match_body(mockito::Matcher::PartialJson(serde_json::json!({
                "response_format": {
                    "type": "json_schema",
                    "json_schema": {
                        "name": "TaskOutput",
                    }
                }
            })))
            .expect(usize::MAX)
            .with_status(200)
            .with_header("content-type", "application/json")
            .with_body(serde_json::to_string(&chat_structured_task_output).unwrap())
            .create();

        // Structured output: `Score` schema name -> randomized score body.
        server
            .mock("POST", "/chat/completions")
            .match_body(mockito::Matcher::PartialJson(serde_json::json!({
                "response_format": {
                    "type": "json_schema",
                    "json_schema": {
                        "name": "Score",
                    }
                }
            })))
            .expect(usize::MAX)
            .with_status(200)
            .with_header("content-type", "application/json")
            // NOTE(review): this static body is immediately followed by
            // with_body_from_request, which presumably takes precedence — the
            // with_body call below looks redundant; confirm and remove.
            .with_body(serde_json::to_string(&chat_structured_score_response).unwrap())
            .with_body_from_request({
                let chat_structured_score_response = chat_structured_score_response.clone();
                move |_request| {
                    let randomized_response = randomize_structured_openai_score_response(
                        &chat_structured_score_response.clone(),
                    );
                    serde_json::to_string(&randomized_response).unwrap().into()
                }
            })
            .create();

        // Fallback `Score` matcher via regex on the raw body, for clients
        // whose request shape doesn't satisfy the PartialJson matcher above.
        server
            .mock("POST", "/chat/completions")
            .match_body(mockito::Matcher::Regex(
                r#".*"name"\s*:\s*"Score".*"#.to_string(),
            ))
            .expect(usize::MAX)
            .with_status(200)
            .with_header("content-type", "application/json")
            .with_body_from_request({
                let chat_structured_score_response = chat_structured_score_response.clone();
                move |_request| {
                    let randomized_response = randomize_structured_openai_score_response(
                        &chat_structured_score_response.clone(),
                    );
                    serde_json::to_string(&randomized_response).unwrap().into()
                }
            })
            .create();

        // Any other json_schema structured request -> generic structured fixture.
        server
            .mock("POST", "/chat/completions")
            .match_body(mockito::Matcher::PartialJson(serde_json::json!({
                "response_format": {
                    "type": "json_schema"
                }
            })))
            .expect(usize::MAX)
            .with_status(200)
            .with_header("content-type", "application/json")
            .with_body(serde_json::to_string(&chat_structured_response).unwrap())
            .create();

        // Gemini plain chat: `:generateContent` with a system-prompt body.
        server
            .mock(
                "POST",
                mockito::Matcher::Regex(r".*/.*:generateContent$".to_string()),
            )
            .match_header("x-goog-api-key", mockito::Matcher::Any)
            .match_header("content-type", "application/json")
            .match_body(mockito::Matcher::PartialJson(serde_json::json!({
                "contents": [
                    {
                        "parts": [
                            {
                                "text": "You are a helpful assistant"
                            }
                        ]
                    }
                ]
            })))
            .expect(usize::MAX)
            .with_status(200)
            .with_header("content-type", "application/json")
            .with_body(serde_json::to_string(&gemini_chat_response).unwrap())
            .create();

        // Gemini structured chat: JSON mime type requested -> randomized score.
        server
            .mock(
                "POST",
                mockito::Matcher::Regex(r".*/.*:generateContent$".to_string()),
            )
            .match_header("x-goog-api-key", mockito::Matcher::Any)
            .match_header("content-type", "application/json")
            .match_body(mockito::Matcher::PartialJson(serde_json::json!({
                "generation_config": {
                    "responseMimeType": "application/json"
                }
            })))
            .expect(usize::MAX)
            .with_status(200)
            .with_header("content-type", "application/json")
            .with_body_from_request(move |_request| {
                let randomized_response =
                    randomize_gemini_score_response(gemini_chat_response_with_score.clone());
                serde_json::to_string(&randomized_response).unwrap().into()
            })
            .create();

        // Catch-all OpenAI chat mock (no body matcher) -> plain chat fixture.
        server
            .mock("POST", "/chat/completions")
            .expect(usize::MAX)
            .with_status(200)
            .with_header("content-type", "application/json")
            .with_body(serde_json::to_string(&chat_msg_response).unwrap())
            .create();

        // OpenAI embeddings -> randomized 512-dim vector per request.
        server
            .mock("POST", "/embeddings")
            .expect(usize::MAX)
            .with_status(200)
            .with_header("content-type", "application/json")
            .with_body_from_request(move |_request| {
                let randomized_response =
                    randomize_openai_embedding_response(openai_embedding_response.clone());
                serde_json::to_string(&randomized_response).unwrap().into()
            })
            .create();

        // Gemini embeddings -> randomized 512-dim vector per request.
        server
            .mock(
                "POST",
                mockito::Matcher::Regex(r".*/.*:embedContent$".to_string()),
            )
            .expect(usize::MAX)
            .with_status(200)
            .with_header("content-type", "application/json")
            .with_body_from_request(move |_request| {
                let randomized_response =
                    randomize_gemini_embedding_response(gemini_embedding_response.clone());
                serde_json::to_string(&randomized_response).unwrap().into()
            })
            .create();

        Self {
            url: server.url(),
            server,
        }
    }
}
341
342impl Default for LLMApiMock {
343 fn default() -> Self {
344 Self::new()
345 }
346}
347
/// Python-facing wrapper around [`LLMApiMock`]: starts/stops the mock server
/// and manages the environment variables clients use to find it. Usable as a
/// Python context manager via `__enter__`/`__exit__`.
#[pyclass]
#[allow(dead_code)]
pub struct LLMTestServer {
    // Running mock server, if started. Named `openai_server` but serves both
    // OpenAI- and Gemini-style endpoints.
    openai_server: Option<LLMApiMock>,

    // Base URL of the running server, exposed to Python as a read-only attr.
    #[pyo3(get)]
    pub url: Option<String>,
}
356
357#[pymethods]
358impl LLMTestServer {
359 #[new]
360 pub fn new() -> Self {
361 LLMTestServer {
362 openai_server: None,
363 url: None,
364 }
365 }
366
367 pub fn start_mock_server(&mut self) -> Result<(), MockError> {
368 let llm_server = LLMApiMock::new();
369 println!("Mock LLM Server started at {}", llm_server.url);
370 self.openai_server = Some(llm_server);
371 Ok(())
372 }
373
374 pub fn stop_mock_server(&mut self) {
375 if let Some(server) = self.openai_server.take() {
376 drop(server);
377 std::env::remove_var("OPENAI_API_URL");
378 std::env::remove_var("OPENAI_API_KEY");
379 }
380 println!("Mock LLM Server stopped");
381 }
382
383 pub fn set_env_vars_for_client(&self) -> Result<(), MockError> {
384 {
385 std::env::set_var("APP_ENV", "dev_client");
386 std::env::set_var("OPENAI_API_KEY", "test_key");
387 std::env::set_var("GEMINI_API_KEY", "gemini");
388 std::env::set_var(
389 "OPENAI_API_URL",
390 self.openai_server.as_ref().unwrap().url.clone(),
391 );
392 std::env::set_var(
393 "GEMINI_API_URL",
394 self.openai_server.as_ref().unwrap().url.clone(),
395 );
396
397 Ok(())
398 }
399 }
400
401 pub fn start_server(&mut self) -> Result<(), MockError> {
402 self.cleanup()?;
403
404 println!("Starting Mock GenAI Server...");
405 self.start_mock_server()?;
406 self.set_env_vars_for_client()?;
407
408 std::env::set_var("APP_ENV", "dev_server");
410
411 self.url = Some(self.openai_server.as_ref().unwrap().url.clone());
412
413 Ok(())
414 }
415
416 pub fn stop_server(&mut self) -> Result<(), MockError> {
417 self.cleanup()?;
418
419 Ok(())
420 }
421
422 pub fn remove_env_vars_for_client(&self) -> Result<(), MockError> {
423 std::env::remove_var("OPENAI_API_URI");
424 std::env::remove_var("OPENAI_API_KEY");
425 std::env::remove_var("GEMINI_API_KEY");
426 std::env::remove_var("GEMINI_API_URL");
427 Ok(())
428 }
429
430 fn cleanup(&self) -> Result<(), MockError> {
431 self.remove_env_vars_for_client()?;
433
434 Ok(())
435 }
436
437 fn __enter__(mut self_: PyRefMut<Self>) -> Result<PyRefMut<Self>, MockError> {
438 self_.start_server()?;
439
440 Ok(self_)
441 }
442
443 fn __exit__(
444 &mut self,
445 _exc_type: Py<PyAny>,
446 _exc_value: Py<PyAny>,
447 _traceback: Py<PyAny>,
448 ) -> Result<(), MockError> {
449 self.stop_server()
450 }
451}
452
453impl Default for LLMTestServer {
454 fn default() -> Self {
455 Self::new()
456 }
457}