openrouter_rs/api/
completion.rs

1use std::collections::HashMap;
2
3use serde::{Deserialize, Serialize};
4use surf::http::headers::AUTHORIZATION;
5
6use crate::{
7    error::OpenRouterError,
8    setter,
9    types::{ProviderPreferences, ReasoningConfig},
10    utils::handle_error,
11};
12
13#[derive(Serialize, Deserialize, Debug)]
14pub struct CompletionRequest {
15    model: String,
16    prompt: String,
17    stream: Option<bool>,
18    max_tokens: Option<u32>,
19    temperature: Option<f64>,
20    seed: Option<u32>,
21    top_p: Option<f64>,
22    top_k: Option<u32>,
23    frequency_penalty: Option<f64>,
24    presence_penalty: Option<f64>,
25    repetition_penalty: Option<f64>,
26    logit_bias: Option<HashMap<String, f64>>,
27    top_logprobs: Option<u32>,
28    min_p: Option<f64>,
29    top_a: Option<f64>,
30    transforms: Option<Vec<String>>,
31    models: Option<Vec<String>>,
32    route: Option<String>,
33    provider: Option<ProviderPreferences>,
34    reasoning: Option<ReasoningConfig>,
35}
36
37impl CompletionRequest {
38    pub fn new(model: &str, prompt: &str) -> Self {
39        Self {
40            model: model.to_string(),
41            prompt: prompt.to_string(),
42            stream: None,
43            max_tokens: None,
44            temperature: None,
45            seed: None,
46            top_p: None,
47            top_k: None,
48            frequency_penalty: None,
49            presence_penalty: None,
50            repetition_penalty: None,
51            logit_bias: None,
52            top_logprobs: None,
53            min_p: None,
54            top_a: None,
55            transforms: None,
56            models: None,
57            route: None,
58            provider: None,
59            reasoning: None,
60        }
61    }
62
63    setter!(stream, bool);
64    setter!(max_tokens, u32);
65    setter!(temperature, f64);
66    setter!(seed, u32);
67    setter!(top_p, f64);
68    setter!(top_k, u32);
69    setter!(frequency_penalty, f64);
70    setter!(presence_penalty, f64);
71    setter!(repetition_penalty, f64);
72    setter!(logit_bias, HashMap<String, f64>);
73    setter!(top_logprobs, u32);
74    setter!(min_p, f64);
75    setter!(top_a, f64);
76    setter!(transforms, Vec<String>);
77    setter!(models, Vec<String>);
78    setter!(route, String);
79    setter!(provider, ProviderPreferences);
80    setter!(reasoning, ReasoningConfig);
81}
82
83#[derive(Serialize, Deserialize, Debug)]
84pub struct CompletionResponse {
85    id: Option<String>,
86    choices: Option<Vec<Choice>>,
87}
88
89#[derive(Serialize, Deserialize, Debug)]
90pub struct Choice {
91    text: Option<String>,
92    index: Option<u32>,
93    finish_reason: Option<String>,
94}
95
96/// Send a completion request to a selected model (text-only format)
97///
98/// # Arguments
99///
100/// * `base_url` - The API URL for the request.
101/// * `api_key` - The API key for authentication.
102/// * `request` - The completion request containing the model, prompt, and other optional parameters.
103///
104/// # Returns
105///
106/// * `Result<CompletionResponse, OpenRouterError>` - The response from the completion request, containing the generated text and other details.
107pub async fn send_completion_request(
108    base_url: &str,
109    api_key: &str,
110    request: &CompletionRequest,
111) -> Result<CompletionResponse, OpenRouterError> {
112    let url = format!("{}/completions", base_url);
113
114    let mut response = surf::post(url)
115        .header(AUTHORIZATION, format!("Bearer {}", api_key))
116        .body_json(request)?
117        .await?;
118
119    if response.status().is_success() {
120        let completion_response = response.body_json().await?;
121        Ok(completion_response)
122    } else {
123        handle_error(response).await?;
124        unreachable!()
125    }
126}