// openrouter_rs/api/completion.rs

1use std::collections::HashMap;
2
3use derive_builder::Builder;
4use serde::{Deserialize, Serialize};
5use surf::http::headers::AUTHORIZATION;
6
7use crate::{
8    error::OpenRouterError,
9    strip_option_map_setter, strip_option_vec_setter,
10    types::{
11        ProviderPreferences, ReasoningConfig, ResponseFormat, completion::CompletionsResponse,
12    },
13    utils::handle_error,
14};
15
16#[derive(Serialize, Deserialize, Debug, Builder)]
17#[builder(build_fn(error = "OpenRouterError"))]
18pub struct CompletionRequest {
19    #[builder(setter(into))]
20    model: String,
21
22    #[builder(setter(into))]
23    prompt: String,
24
25    #[builder(setter(strip_option), default)]
26    #[serde(skip_serializing_if = "Option::is_none")]
27    stream: Option<bool>,
28
29    #[builder(setter(strip_option), default)]
30    #[serde(skip_serializing_if = "Option::is_none")]
31    max_tokens: Option<u32>,
32
33    #[builder(setter(strip_option), default)]
34    #[serde(skip_serializing_if = "Option::is_none")]
35    temperature: Option<f64>,
36
37    #[builder(setter(strip_option), default)]
38    #[serde(skip_serializing_if = "Option::is_none")]
39    seed: Option<u32>,
40
41    #[builder(setter(strip_option), default)]
42    #[serde(skip_serializing_if = "Option::is_none")]
43    top_p: Option<f64>,
44
45    #[builder(setter(strip_option), default)]
46    #[serde(skip_serializing_if = "Option::is_none")]
47    top_k: Option<u32>,
48
49    #[builder(setter(strip_option), default)]
50    #[serde(skip_serializing_if = "Option::is_none")]
51    frequency_penalty: Option<f64>,
52
53    #[builder(setter(strip_option), default)]
54    #[serde(skip_serializing_if = "Option::is_none")]
55    presence_penalty: Option<f64>,
56
57    #[builder(setter(strip_option), default)]
58    #[serde(skip_serializing_if = "Option::is_none")]
59    repetition_penalty: Option<f64>,
60
61    #[builder(setter(custom), default)]
62    #[serde(skip_serializing_if = "Option::is_none")]
63    logit_bias: Option<HashMap<String, f64>>,
64
65    #[builder(setter(strip_option), default)]
66    #[serde(skip_serializing_if = "Option::is_none")]
67    top_logprobs: Option<u32>,
68
69    #[builder(setter(strip_option), default)]
70    #[serde(skip_serializing_if = "Option::is_none")]
71    min_p: Option<f64>,
72
73    #[builder(setter(strip_option), default)]
74    #[serde(skip_serializing_if = "Option::is_none")]
75    top_a: Option<f64>,
76
77    #[builder(setter(custom), default)]
78    #[serde(skip_serializing_if = "Option::is_none")]
79    transforms: Option<Vec<String>>,
80
81    #[builder(setter(custom), default)]
82    #[serde(skip_serializing_if = "Option::is_none")]
83    models: Option<Vec<String>>,
84
85    #[builder(setter(into, strip_option), default)]
86    #[serde(skip_serializing_if = "Option::is_none")]
87    route: Option<String>,
88
89    #[builder(setter(strip_option), default)]
90    #[serde(skip_serializing_if = "Option::is_none")]
91    provider: Option<ProviderPreferences>,
92
93    #[builder(setter(strip_option), default)]
94    #[serde(skip_serializing_if = "Option::is_none")]
95    response_format: Option<ResponseFormat>,
96
97    #[builder(setter(strip_option), default)]
98    #[serde(skip_serializing_if = "Option::is_none")]
99    reasoning: Option<ReasoningConfig>,
100}
101
102impl CompletionRequestBuilder {
103    strip_option_vec_setter!(models, String);
104    strip_option_map_setter!(logit_bias, String, f64);
105    strip_option_vec_setter!(transforms, String);
106}
107
108impl CompletionRequest {
109    pub fn builder() -> CompletionRequestBuilder {
110        CompletionRequestBuilder::default()
111    }
112
113    pub fn new(model: &str, prompt: &str) -> Self {
114        Self::builder()
115            .model(model)
116            .prompt(prompt)
117            .build()
118            .expect("Failed to build CompletionRequest")
119    }
120}
121
122/// Send a completion request to a selected model (text-only format)
123///
124/// # Arguments
125///
126/// * `base_url` - The API URL for the request.
127/// * `api_key` - The API key for authentication.
128/// * `x_title` - The name of the site for the request.
129/// * `http_referer` - The URL of the site for the request.
130/// * `request` - The completion request containing the model, prompt, and other optional parameters.
131///
132/// # Returns
133///
134/// * `Result<CompletionsResponse, OpenRouterError>` - The response from the completion request, containing the generated text and other details.
135pub async fn send_completion_request(
136    base_url: &str,
137    api_key: &str,
138    x_title: &Option<String>,
139    http_referer: &Option<String>,
140    request: &CompletionRequest,
141) -> Result<CompletionsResponse, OpenRouterError> {
142    let url = format!("{base_url}/completions");
143
144    let mut surf_req = surf::post(url)
145        .header(AUTHORIZATION, format!("Bearer {api_key}"))
146        .body_json(request)?;
147
148    if let Some(x_title) = x_title {
149        surf_req = surf_req.header("X-Title", x_title);
150    }
151    if let Some(http_referer) = http_referer {
152        surf_req = surf_req.header("HTTP-Referer", http_referer);
153    }
154
155    let mut response = surf_req.await?;
156
157    if response.status().is_success() {
158        let completion_response = response.body_json().await?;
159        Ok(completion_response)
160    } else {
161        handle_error(response).await?;
162        unreachable!()
163    }
164}