openai_mock/handlers/
completion_handler.rs1use crate::models::{CompletionRequest, CompletionResponse, Usage};
7use crate::validators::{
8 validate_temperature, validate_top_p, validate_n, validate_max_tokens,
9 validate_presence_penalty, validate_frequency_penalty, validate_best_of,
10 validate_logprobs, validate_stop,
11};
12use crate::validators::StopSequence;
13use crate::validators::validate_required_fields;
14use actix_web::{web, HttpResponse, Responder};
15use serde_json::json;
16use crate::utils::utils::{generate_uuid, get_current_timestamp};
17use crate::utils::choices::create_choices;
18
19pub async fn completions_handler(
35 req: web::Json<CompletionRequest>,
36) -> impl Responder {
37 if let Err(validation_error) = validate_required_fields(&req) {
39 return HttpResponse::BadRequest().json(json!({
40 "error": {
41 "message": validation_error.to_string(),
42 "type": "invalid_request_error",
43 "param": "model",
44 "code": null,
45 }
46 }));
47 }
48
49 let validators = [
51 ("temperature", validate_temperature(req.temperature)),
52 ("top_p", validate_top_p(req.top_p)),
53 ("n", validate_n(req.n)),
54 ("max_tokens", validate_max_tokens(req.max_tokens)),
55 ("presence_penalty", validate_presence_penalty(req.presence_penalty)),
56 ("frequency_penalty", validate_frequency_penalty(req.frequency_penalty)),
57 ("logprobs", validate_logprobs(req.logprobs)),
58 ("stop", validate_stop(req.stop.clone())),
59 ("best_of", validate_best_of(req.best_of, req.n)),
60 ];
61
62 for (field, result) in validators {
64 if let Err(validation_error) = result {
65 return HttpResponse::BadRequest().json(json!({
66 "error": {
67 "message": validation_error,
68 "type": "invalid_request_error",
69 "param": field,
70 "code": null,
71 }
72 }));
73 }
74 }
75
76 let prompt = req.prompt.clone().unwrap_or_default();
78 let max_tokens = req.max_tokens.unwrap_or(16);
79 let n = req.n.unwrap_or(1);
80 let echo = req.echo.unwrap_or(false);
81 let logprobs = req.logprobs;
82
83 let stop_sequences = match &req.stop {
84 Some(StopSequence::Single(s)) => vec![s.clone()],
85 Some(StopSequence::Multiple(v)) => v.clone(),
86 None => Vec::new(),
87 };
88
89 let choices = create_choices(
90 n,
91 &prompt.to_string(),
92 &stop_sequences,
93 max_tokens,
94 echo,
95 logprobs,
96 &req.model
97 );
98
99 let response = CompletionResponse {
100 id: format!("cmpl-mock-id-{}", generate_uuid()),
101 object: "text_completion".to_string(),
102 created: get_current_timestamp().timestamp() as u64,
103 model: req.model.clone(),
104 choices,
105 usage: Usage {
106 prompt_tokens: count_tokens(&prompt.to_string()),
107 completion_tokens: max_tokens,
108 total_tokens: count_tokens(&prompt.to_string()) + max_tokens,
109 },
110 };
111 HttpResponse::Ok().json(response)
112}
113
/// Rough token estimate used for the mock's usage accounting.
///
/// Counts whitespace-separated words: leading, trailing, and repeated
/// whitespace never contribute a token, so an empty or blank string yields 0.
fn count_tokens(text: &str) -> u32 {
    let word_total: usize = text.split_whitespace().map(|_| 1usize).sum();
    word_total as u32
}