1use std::collections::HashMap;
2
3use serde::{Deserialize, Serialize};
4use surf::http::headers::AUTHORIZATION;
5
6use crate::{
7 error::OpenRouterError,
8 setter,
9 types::{ProviderPreferences, ReasoningConfig},
10 utils::handle_error,
11};
12
13#[derive(Serialize, Deserialize, Debug)]
14pub struct CompletionRequest {
15 model: String,
16 prompt: String,
17 stream: Option<bool>,
18 max_tokens: Option<u32>,
19 temperature: Option<f64>,
20 seed: Option<u32>,
21 top_p: Option<f64>,
22 top_k: Option<u32>,
23 frequency_penalty: Option<f64>,
24 presence_penalty: Option<f64>,
25 repetition_penalty: Option<f64>,
26 logit_bias: Option<HashMap<String, f64>>,
27 top_logprobs: Option<u32>,
28 min_p: Option<f64>,
29 top_a: Option<f64>,
30 transforms: Option<Vec<String>>,
31 models: Option<Vec<String>>,
32 route: Option<String>,
33 provider: Option<ProviderPreferences>,
34 reasoning: Option<ReasoningConfig>,
35}
36
37#[derive(Default)]
38pub struct CompletionRequestBuilder {
39 model: Option<String>,
40 prompt: Option<String>,
41 stream: Option<bool>,
42 max_tokens: Option<u32>,
43 temperature: Option<f64>,
44 seed: Option<u32>,
45 top_p: Option<f64>,
46 top_k: Option<u32>,
47 frequency_penalty: Option<f64>,
48 presence_penalty: Option<f64>,
49 repetition_penalty: Option<f64>,
50 logit_bias: Option<HashMap<String, f64>>,
51 top_logprobs: Option<u32>,
52 min_p: Option<f64>,
53 top_a: Option<f64>,
54 transforms: Option<Vec<String>>,
55 models: Option<Vec<String>>,
56 route: Option<String>,
57 provider: Option<ProviderPreferences>,
58 reasoning: Option<ReasoningConfig>,
59}
60
61impl CompletionRequestBuilder {
62 pub fn new() -> Self {
63 Self::default()
64 }
65
66 setter!(model, into String);
67 setter!(prompt, into String);
68 setter!(stream, bool);
69 setter!(max_tokens, u32);
70 setter!(temperature, f64);
71 setter!(seed, u32);
72 setter!(top_p, f64);
73 setter!(top_k, u32);
74 setter!(frequency_penalty, f64);
75 setter!(presence_penalty, f64);
76 setter!(repetition_penalty, f64);
77 setter!(logit_bias, HashMap<String, f64>);
78 setter!(top_logprobs, u32);
79 setter!(min_p, f64);
80 setter!(top_a, f64);
81 setter!(transforms, Vec<String>);
82 setter!(models, Vec<String>);
83 setter!(route, String);
84 setter!(provider, ProviderPreferences);
85 setter!(reasoning, ReasoningConfig);
86
87 pub fn build(self) -> Result<CompletionRequest, OpenRouterError> {
88 Ok(CompletionRequest {
89 model: self
90 .model
91 .ok_or(OpenRouterError::Validation("model is required".into()))?,
92 prompt: self
93 .prompt
94 .ok_or(OpenRouterError::Validation("prompt is required".into()))?,
95 stream: self.stream,
96 max_tokens: self.max_tokens,
97 temperature: self.temperature,
98 seed: self.seed,
99 top_p: self.top_p,
100 top_k: self.top_k,
101 frequency_penalty: self.frequency_penalty,
102 presence_penalty: self.presence_penalty,
103 repetition_penalty: self.repetition_penalty,
104 logit_bias: self.logit_bias,
105 top_logprobs: self.top_logprobs,
106 min_p: self.min_p,
107 top_a: self.top_a,
108 transforms: self.transforms,
109 models: self.models,
110 route: self.route,
111 provider: self.provider,
112 reasoning: self.reasoning,
113 })
114 }
115}
116
117impl CompletionRequest {
118 pub fn builder() -> CompletionRequestBuilder {
119 CompletionRequestBuilder::new()
120 }
121
122 pub fn new(model: &str, prompt: &str) -> Self {
123 Self::builder()
124 .model(model)
125 .prompt(prompt)
126 .build()
127 .expect("Failed to build CompletionRequest")
128 }
129}
130
131#[derive(Serialize, Deserialize, Debug)]
132pub struct CompletionResponse {
133 pub id: Option<String>,
134 pub choices: Option<Vec<Choice>>,
135}
136
137#[derive(Serialize, Deserialize, Debug)]
138pub struct Choice {
139 pub text: Option<String>,
140 pub index: Option<u32>,
141 pub finish_reason: Option<String>,
142}
143
144pub async fn send_completion_request(
156 base_url: &str,
157 api_key: &str,
158 request: &CompletionRequest,
159) -> Result<CompletionResponse, OpenRouterError> {
160 let url = format!("{}/completions", base_url);
161
162 let mut response = surf::post(url)
163 .header(AUTHORIZATION, format!("Bearer {}", api_key))
164 .body_json(request)?
165 .await?;
166
167 if response.status().is_success() {
168 let completion_response = response.body_json().await?;
169 Ok(completion_response)
170 } else {
171 handle_error(response).await?;
172 unreachable!()
173 }
174}