openrouter_rs/api/completion.rs

1use std::collections::HashMap;
2
3use serde::{Deserialize, Serialize};
4use surf::http::headers::AUTHORIZATION;
5
6use crate::{
7    error::OpenRouterError,
8    setter,
9    types::{ProviderPreferences, ReasoningConfig},
10    utils::handle_error,
11};
12
13#[derive(Serialize, Deserialize, Debug)]
14pub struct CompletionRequest {
15    model: String,
16    prompt: String,
17    stream: Option<bool>,
18    max_tokens: Option<u32>,
19    temperature: Option<f64>,
20    seed: Option<u32>,
21    top_p: Option<f64>,
22    top_k: Option<u32>,
23    frequency_penalty: Option<f64>,
24    presence_penalty: Option<f64>,
25    repetition_penalty: Option<f64>,
26    logit_bias: Option<HashMap<String, f64>>,
27    top_logprobs: Option<u32>,
28    min_p: Option<f64>,
29    top_a: Option<f64>,
30    transforms: Option<Vec<String>>,
31    models: Option<Vec<String>>,
32    route: Option<String>,
33    provider: Option<ProviderPreferences>,
34    reasoning: Option<ReasoningConfig>,
35}
36
37#[derive(Default)]
38pub struct CompletionRequestBuilder {
39    model: Option<String>,
40    prompt: Option<String>,
41    stream: Option<bool>,
42    max_tokens: Option<u32>,
43    temperature: Option<f64>,
44    seed: Option<u32>,
45    top_p: Option<f64>,
46    top_k: Option<u32>,
47    frequency_penalty: Option<f64>,
48    presence_penalty: Option<f64>,
49    repetition_penalty: Option<f64>,
50    logit_bias: Option<HashMap<String, f64>>,
51    top_logprobs: Option<u32>,
52    min_p: Option<f64>,
53    top_a: Option<f64>,
54    transforms: Option<Vec<String>>,
55    models: Option<Vec<String>>,
56    route: Option<String>,
57    provider: Option<ProviderPreferences>,
58    reasoning: Option<ReasoningConfig>,
59}
60
61impl CompletionRequestBuilder {
62    pub fn new() -> Self {
63        Self::default()
64    }
65
66    setter!(model, into String);
67    setter!(prompt, into String);
68    setter!(stream, bool);
69    setter!(max_tokens, u32);
70    setter!(temperature, f64);
71    setter!(seed, u32);
72    setter!(top_p, f64);
73    setter!(top_k, u32);
74    setter!(frequency_penalty, f64);
75    setter!(presence_penalty, f64);
76    setter!(repetition_penalty, f64);
77    setter!(logit_bias, HashMap<String, f64>);
78    setter!(top_logprobs, u32);
79    setter!(min_p, f64);
80    setter!(top_a, f64);
81    setter!(transforms, Vec<String>);
82    setter!(models, Vec<String>);
83    setter!(route, String);
84    setter!(provider, ProviderPreferences);
85    setter!(reasoning, ReasoningConfig);
86
87    pub fn build(self) -> Result<CompletionRequest, OpenRouterError> {
88        Ok(CompletionRequest {
89            model: self
90                .model
91                .ok_or(OpenRouterError::Validation("model is required".into()))?,
92            prompt: self
93                .prompt
94                .ok_or(OpenRouterError::Validation("prompt is required".into()))?,
95            stream: self.stream,
96            max_tokens: self.max_tokens,
97            temperature: self.temperature,
98            seed: self.seed,
99            top_p: self.top_p,
100            top_k: self.top_k,
101            frequency_penalty: self.frequency_penalty,
102            presence_penalty: self.presence_penalty,
103            repetition_penalty: self.repetition_penalty,
104            logit_bias: self.logit_bias,
105            top_logprobs: self.top_logprobs,
106            min_p: self.min_p,
107            top_a: self.top_a,
108            transforms: self.transforms,
109            models: self.models,
110            route: self.route,
111            provider: self.provider,
112            reasoning: self.reasoning,
113        })
114    }
115}
116
117impl CompletionRequest {
118    pub fn builder() -> CompletionRequestBuilder {
119        CompletionRequestBuilder::new()
120    }
121
122    pub fn new(model: &str, prompt: &str) -> Self {
123        Self::builder()
124            .model(model)
125            .prompt(prompt)
126            .build()
127            .expect("Failed to build CompletionRequest")
128    }
129}
130
131#[derive(Serialize, Deserialize, Debug)]
132pub struct CompletionResponse {
133    pub id: Option<String>,
134    pub choices: Option<Vec<Choice>>,
135}
136
137#[derive(Serialize, Deserialize, Debug)]
138pub struct Choice {
139    pub text: Option<String>,
140    pub index: Option<u32>,
141    pub finish_reason: Option<String>,
142}
143
144/// Send a completion request to a selected model (text-only format)
145///
146/// # Arguments
147///
148/// * `base_url` - The API URL for the request.
149/// * `api_key` - The API key for authentication.
150/// * `request` - The completion request containing the model, prompt, and other optional parameters.
151///
152/// # Returns
153///
154/// * `Result<CompletionResponse, OpenRouterError>` - The response from the completion request, containing the generated text and other details.
155pub async fn send_completion_request(
156    base_url: &str,
157    api_key: &str,
158    request: &CompletionRequest,
159) -> Result<CompletionResponse, OpenRouterError> {
160    let url = format!("{}/completions", base_url);
161
162    let mut response = surf::post(url)
163        .header(AUTHORIZATION, format!("Bearer {}", api_key))
164        .body_json(request)?
165        .await?;
166
167    if response.status().is_success() {
168        let completion_response = response.body_json().await?;
169        Ok(completion_response)
170    } else {
171        handle_error(response).await?;
172        unreachable!()
173    }
174}