openai_mock/handlers/completion_handler.rs

//! This module handles HTTP requests for generating text completions.
//!
//! It provides the `completions_handler` function, which processes incoming
//! completion requests, validates them, and returns appropriate responses.
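//!
//! A minimal registration sketch (the module path and route shown here are
//! illustrative and may differ from the crate's actual wiring):
//!
//! ```ignore
//! use actix_web::{web, App};
//! use openai_mock::handlers::completion_handler::completions_handler;
//!
//! let app = App::new().route("/completions", web::post().to(completions_handler));
//! ```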
use actix_web::{web, HttpResponse, Responder};
use serde_json::json;

use crate::models::{CompletionRequest, CompletionResponse, Usage};
use crate::utils::choices::create_choices;
use crate::utils::utils::{generate_uuid, get_current_timestamp};
use crate::validators::{
    StopSequence, validate_best_of, validate_frequency_penalty, validate_logprobs,
    validate_max_tokens, validate_n, validate_presence_penalty, validate_required_fields,
    validate_stop, validate_temperature, validate_top_p,
};

/// Handles the `/completions` endpoint for generating text completions.
///
/// This asynchronous function processes a `CompletionRequest`, validates
/// the required and optional fields, generates completion choices, and
/// constructs a `CompletionResponse`. In case of validation errors, it
/// returns a `BadRequest` response with relevant error messages.
///
/// # Parameters
///
/// - `req`: A JSON payload deserialized into `CompletionRequest`.
///
/// # Returns
///
/// An `HttpResponse` containing the `CompletionResponse` on success or
/// an error message on failure.
pub async fn completions_handler(req: web::Json<CompletionRequest>) -> impl Responder {
    // Validate the required fields using the validator.
    if let Err(validation_error) = validate_required_fields(&req) {
        return HttpResponse::BadRequest().json(json!({
            "error": {
                "message": validation_error.to_string(),
                "type": "invalid_request_error",
                "param": "model",
                "code": null,
            }
        }));
    }

    // Validate optional fields.
    let validators = [
        ("temperature", validate_temperature(req.temperature)),
        ("top_p", validate_top_p(req.top_p)),
        ("n", validate_n(req.n)),
        ("max_tokens", validate_max_tokens(req.max_tokens)),
        ("presence_penalty", validate_presence_penalty(req.presence_penalty)),
        ("frequency_penalty", validate_frequency_penalty(req.frequency_penalty)),
        ("logprobs", validate_logprobs(req.logprobs)),
        ("stop", validate_stop(req.stop.clone())),
        ("best_of", validate_best_of(req.best_of, req.n)),
    ];

    // Check each validation result and surface the first failure.
    for (field, result) in validators {
        if let Err(validation_error) = result {
            return HttpResponse::BadRequest().json(json!({
                "error": {
                    "message": validation_error,
                    "type": "invalid_request_error",
                    "param": field,
                    "code": null,
                }
            }));
        }
    }

    // Mock processing logic: fill in defaults for omitted optional fields.
    let prompt = req.prompt.clone().unwrap_or_default();
    let max_tokens = req.max_tokens.unwrap_or(16);
    let n = req.n.unwrap_or(1);
    let echo = req.echo.unwrap_or(false);
    let logprobs = req.logprobs;

    // Normalize `stop` into a list of stop sequences.
    let stop_sequences = match &req.stop {
        Some(StopSequence::Single(s)) => vec![s.clone()],
        Some(StopSequence::Multiple(v)) => v.clone(),
        None => Vec::new(),
    };

    // Render the prompt once; it is reused for choice generation and token counting.
    let prompt_text = prompt.to_string();
    let prompt_tokens = count_tokens(&prompt_text);

    let choices = create_choices(
        n,
        &prompt_text,
        &stop_sequences,
        max_tokens,
        echo,
        logprobs,
        &req.model,
    );

    let response = CompletionResponse {
        id: format!("cmpl-mock-id-{}", generate_uuid()),
        object: "text_completion".to_string(),
        created: get_current_timestamp().timestamp() as u64,
        model: req.model.clone(),
        choices,
        usage: Usage {
            prompt_tokens,
            // The mock reports the requested `max_tokens` as the completion
            // length instead of counting the generated text.
            completion_tokens: max_tokens,
            total_tokens: prompt_tokens + max_tokens,
        },
    };
    HttpResponse::Ok().json(response)
}
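
// A request-level sketch of exercising the handler with actix-web's test
// utilities. It assumes actix-web 4 with the default `macros` feature and that
// `CompletionRequest` accepts the JSON shape below; adjust the payload to match
// the actual model definition.
#[cfg(test)]
mod handler_tests {
    use super::*;
    use actix_web::{test, App};
    use serde_json::json;

    #[actix_web::test]
    async fn minimal_request_returns_ok() {
        // Register the handler on a throwaway test app.
        let app = test::init_service(
            App::new().route("/completions", web::post().to(completions_handler)),
        )
        .await;

        // A minimal, well-formed completion request.
        let req = test::TestRequest::post()
            .uri("/completions")
            .set_json(json!({
                "model": "text-davinci-003",
                "prompt": "Say hello"
            }))
            .to_request();

        let resp = test::call_service(&app, req).await;
        assert!(resp.status().is_success());
    }
}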

/// Counts the number of tokens in a given text.
///
/// This is a mock implementation that simply counts whitespace-separated
/// words. In a real-world scenario, a proper tokenizer should be used.
///
/// # Parameters
///
/// - `text`: The input string to tokenize.
///
/// # Returns
///
/// The number of tokens as a `u32`.
fn count_tokens(text: &str) -> u32 {
    // This is a placeholder. In a real scenario, you might use a tokenizer.
    text.split_whitespace().count() as u32
}
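
// A small sanity check that pins down the whitespace-based mock behavior of
// `count_tokens` above (not a real tokenizer).
#[cfg(test)]
mod count_tokens_tests {
    use super::count_tokens;

    #[test]
    fn counts_whitespace_separated_words() {
        assert_eq!(count_tokens(""), 0);
        assert_eq!(count_tokens("hello"), 1);
        assert_eq!(count_tokens("hello  world \n foo"), 3);
    }
}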