1use crate::{
13 agent::AgentBuilder,
14 completion::{self, CompletionError},
15 extractor::ExtractorBuilder,
16 json_utils,
17};
18
19use schemars::JsonSchema;
20use serde::{Deserialize, Serialize};
21use serde_json::json;
22
/// Root URL of the Perplexity REST API; [`Client::new`] targets this URL.
const PERPLEXITY_API_BASE_URL: &str = "https://api.perplexity.ai";
27
28#[derive(Clone)]
29pub struct Client {
30 base_url: String,
31 http_client: reqwest::Client,
32}
33
34impl Client {
35 pub fn new(api_key: &str) -> Self {
36 Self::from_url(api_key, PERPLEXITY_API_BASE_URL)
37 }
38
39 pub fn from_env() -> Self {
42 let api_key = std::env::var("PERPLEXITY_API_KEY").expect("PERPLEXITY_API_KEY not set");
43 Self::new(&api_key)
44 }
45
46 pub fn from_url(api_key: &str, base_url: &str) -> Self {
47 Self {
48 base_url: base_url.to_string(),
49 http_client: reqwest::Client::builder()
50 .default_headers({
51 let mut headers = reqwest::header::HeaderMap::new();
52 headers.insert(
53 "Authorization",
54 format!("Bearer {}", api_key)
55 .parse()
56 .expect("Bearer token should parse"),
57 );
58 headers
59 })
60 .build()
61 .expect("Perplexity reqwest client should build"),
62 }
63 }
64
65 pub fn post(&self, path: &str) -> reqwest::RequestBuilder {
66 let url = format!("{}/{}", self.base_url, path).replace("//", "/");
67 self.http_client.post(url)
68 }
69
70 pub fn completion_model(&self, model: &str) -> CompletionModel {
71 CompletionModel::new(self.clone(), model)
72 }
73
74 pub fn agent(&self, model: &str) -> AgentBuilder<CompletionModel> {
75 AgentBuilder::new(self.completion_model(model))
76 }
77
78 pub fn extractor<T: JsonSchema + for<'a> Deserialize<'a> + Serialize + Send + Sync>(
79 &self,
80 model: &str,
81 ) -> ExtractorBuilder<T, CompletionModel> {
82 ExtractorBuilder::new(self.completion_model(model))
83 }
84}
85
86#[derive(Debug, Deserialize)]
87struct ApiErrorResponse {
88 message: String,
89}
90
91#[derive(Debug, Deserialize)]
92#[serde(untagged)]
93enum ApiResponse<T> {
94 Ok(T),
95 Err(ApiErrorResponse),
96}
97
// Model-name strings for use with `Client::completion_model`, `Client::agent` or
// `Client::extractor`.

/// Model name `llama-3.1-sonar-small-128k-online`.
pub const LLAMA_3_1_SONAR_SMALL_ONLINE: &str = "llama-3.1-sonar-small-128k-online";
/// Model name `llama-3.1-sonar-large-128k-online`.
pub const LLAMA_3_1_SONAR_LARGE_ONLINE: &str = "llama-3.1-sonar-large-128k-online";
/// Model name `llama-3.1-sonar-huge-128k-online`.
pub const LLAMA_3_1_SONAR_HUGE_ONLINE: &str = "llama-3.1-sonar-huge-128k-online";
/// Model name `llama-3.1-sonar-small-128k-chat`.
pub const LLAMA_3_1_SONAR_SMALL_CHAT: &str = "llama-3.1-sonar-small-128k-chat";
/// Model name `llama-3.1-sonar-large-128k-chat`.
pub const LLAMA_3_1_SONAR_LARGE_CHAT: &str = "llama-3.1-sonar-large-128k-chat";
/// Model name `llama-3.1-8b-instruct`.
pub const LLAMA_3_1_8B_INSTRUCT: &str = "llama-3.1-8b-instruct";
/// Model name `llama-3.1-70b-instruct`.
pub const LLAMA_3_1_70B_INSTRUCT: &str = "llama-3.1-70b-instruct";
115
116#[derive(Debug, Deserialize)]
117pub struct CompletionResponse {
118 pub id: String,
119 pub model: String,
120 pub object: String,
121 pub created: u64,
122 #[serde(default)]
123 pub choices: Vec<Choice>,
124 pub usage: Usage,
125}
126
127#[derive(Deserialize, Debug)]
128pub struct Message {
129 pub role: String,
130 pub content: String,
131}
132
133#[derive(Deserialize, Debug)]
134pub struct Delta {
135 pub role: String,
136 pub content: String,
137}
138
139#[derive(Deserialize, Debug)]
140pub struct Choice {
141 pub index: usize,
142 pub finish_reason: String,
143 pub message: Message,
144 pub delta: Delta,
145}
146
147#[derive(Deserialize, Debug)]
148pub struct Usage {
149 pub prompt_tokens: u32,
150 pub completion_tokens: u32,
151 pub total_tokens: u32,
152}
153
154impl std::fmt::Display for Usage {
155 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
156 write!(
157 f,
158 "Prompt tokens: {}\nCompletion tokens: {} Total tokens: {}",
159 self.prompt_tokens, self.completion_tokens, self.total_tokens
160 )
161 }
162}
163
164impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
165 type Error = CompletionError;
166
167 fn try_from(value: CompletionResponse) -> std::prelude::v1::Result<Self, Self::Error> {
168 match value.choices.as_slice() {
169 [Choice {
170 message: Message { content, .. },
171 ..
172 }, ..] => Ok(completion::CompletionResponse {
173 choice: completion::ModelChoice::Message(content.to_string()),
174 raw_response: value,
175 }),
176 _ => Err(CompletionError::ResponseError(
177 "Response did not contain a message or tool call".into(),
178 )),
179 }
180 }
181}
182
183#[derive(Clone)]
184pub struct CompletionModel {
185 client: Client,
186 pub model: String,
187}
188
189impl CompletionModel {
190 pub fn new(client: Client, model: &str) -> Self {
191 Self {
192 client,
193 model: model.to_string(),
194 }
195 }
196}
197
198impl completion::CompletionModel for CompletionModel {
199 type Response = CompletionResponse;
200
201 async fn completion(
202 &self,
203 completion_request: completion::CompletionRequest,
204 ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
205 let mut messages = if let Some(preamble) = &completion_request.preamble {
207 vec![completion::Message {
208 role: "system".into(),
209 content: preamble.clone(),
210 }]
211 } else {
212 vec![]
213 };
214
215 let prompt_with_context = completion_request.prompt_with_context();
217
218 messages.extend(completion_request.chat_history);
220
221 messages.push(completion::Message {
223 role: "user".to_string(),
224 content: prompt_with_context,
225 });
226
227 let request = json!({
228 "model": self.model,
229 "messages": messages,
230 "temperature": completion_request.temperature,
231 });
232
233 let response = self
234 .client
235 .post("/chat/completions")
236 .json(
237 &if let Some(ref params) = completion_request.additional_params {
238 json_utils::merge(request.clone(), params.clone())
239 } else {
240 request.clone()
241 },
242 )
243 .send()
244 .await?;
245
246 if response.status().is_success() {
247 match response.json::<ApiResponse<CompletionResponse>>().await? {
248 ApiResponse::Ok(completion) => {
249 tracing::info!(target: "rig",
250 "Perplexity completion token usage: {}",
251 completion.usage
252 );
253 Ok(completion.try_into()?)
254 }
255 ApiResponse::Err(error) => Err(CompletionError::ProviderError(error.message)),
256 }
257 } else {
258 Err(CompletionError::ProviderError(response.text().await?))
259 }
260 }
261}