// llm_connector/core/traits.rs

use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use std::any::Any;

use crate::types::{ChatRequest, ChatResponse};
use crate::error::LlmConnectorError;

#[cfg(feature = "streaming")]
use crate::types::ChatStream;
/// Wire-format abstraction for one LLM API dialect.
///
/// A `Protocol` knows how to turn the crate-level [`ChatRequest`] into a
/// provider-specific request body, which URL to send it to, how to parse
/// the provider's reply back into a [`ChatResponse`], and how to map HTTP
/// errors into [`LlmConnectorError`]. `GenericProvider` below drives the
/// actual HTTP traffic through these hooks.
#[async_trait]
pub trait Protocol: Send + Sync + Clone + 'static {
    /// Provider-specific request body, serialized as the POST payload.
    type Request: Serialize + Send + Sync;

    /// Provider-specific response body.
    ///
    /// NOTE(review): not referenced by `GenericProvider` in this file —
    /// responses are parsed from raw text via `parse_response` instead.
    /// Confirm whether concrete protocols still need this associated type.
    type Response: for<'de> Deserialize<'de> + Send + Sync;

    /// Human-readable protocol/provider name; used in error messages here.
    fn name(&self) -> &str;

    /// Full URL of the chat-completion endpoint for the given base URL.
    fn chat_endpoint(&self, base_url: &str) -> String;

    /// Full URL of the model-listing endpoint, or `None` when the
    /// protocol has no such endpoint (the default).
    fn models_endpoint(&self, _base_url: &str) -> Option<String> {
        None
    }

    /// Convert a crate-level request into this protocol's wire format.
    fn build_request(&self, request: &ChatRequest) -> Result<Self::Request, LlmConnectorError>;

    /// Parse a successful chat response body (raw text) into a [`ChatResponse`].
    fn parse_response(&self, response: &str) -> Result<ChatResponse, LlmConnectorError>;

    /// Parse a model-listing response body into model names.
    ///
    /// Defaults to `UnsupportedOperation`, mirroring `models_endpoint`'s
    /// default of `None`.
    fn parse_models(&self, _response: &str) -> Result<Vec<String>, LlmConnectorError> {
        Err(LlmConnectorError::UnsupportedOperation(
            format!("{} does not support model listing", self.name())
        ))
    }

    /// Map a non-success HTTP status and response body to a crate error.
    fn map_error(&self, status: u16, body: &str) -> LlmConnectorError;

    /// Extra authentication headers as `(name, value)` pairs; none by default.
    ///
    /// NOTE(review): `GenericProvider` in this file never calls this —
    /// presumably the HTTP client applies these headers. Verify at the call site.
    fn auth_headers(&self) -> Vec<(String, String)> {
        Vec::new()
    }

    /// Turn a streaming HTTP response into a [`ChatStream`].
    ///
    /// The default delegates to the crate's SSE adapter.
    #[cfg(feature = "streaming")]
    async fn parse_stream_response(&self, response: reqwest::Response) -> Result<ChatStream, LlmConnectorError> {
        Ok(crate::sse::sse_to_streaming_response(response))
    }
}
67
/// Object-safe interface for a configured LLM provider.
///
/// Unlike [`Protocol`], this trait has no associated types, so it can be
/// used as `dyn Provider` for heterogeneous collections of providers.
#[async_trait]
pub trait Provider: Send + Sync {
    /// Human-readable provider name.
    fn name(&self) -> &str;

    /// Perform one non-streaming chat completion.
    async fn chat(&self, request: &ChatRequest) -> Result<ChatResponse, LlmConnectorError>;

    /// Perform one streaming chat completion.
    #[cfg(feature = "streaming")]
    async fn chat_stream(&self, request: &ChatRequest) -> Result<ChatStream, LlmConnectorError>;

    /// List the model names available from this provider.
    async fn models(&self) -> Result<Vec<String>, LlmConnectorError>;

    /// Access to the concrete type for downcasting (e.g. via `Any::downcast_ref`).
    fn as_any(&self) -> &dyn Any;
}
90
/// A [`Provider`] implementation generic over any [`Protocol`].
///
/// Pairs a protocol (request building / response parsing) with an HTTP
/// client (transport), so concrete providers only have to implement the
/// protocol half.
pub struct GenericProvider<P: Protocol> {
    // Wire-format logic for this provider's API dialect.
    protocol: P,
    // Transport; also owns the base URL (see `HttpClient::base_url`).
    client: super::HttpClient,
}
100
101impl<P: Protocol> GenericProvider<P> {
102 pub fn new(protocol: P, client: super::HttpClient) -> Self {
104 Self { protocol, client }
105 }
106
107 pub fn protocol(&self) -> &P {
109 &self.protocol
110 }
111
112 pub fn client(&self) -> &super::HttpClient {
114 &self.client
115 }
116}
117
118impl<P: Protocol> Clone for GenericProvider<P> {
119 fn clone(&self) -> Self {
120 Self {
121 protocol: self.protocol.clone(),
122 client: self.client.clone(),
123 }
124 }
125}
126
127#[async_trait]
128impl<P: Protocol> Provider for GenericProvider<P> {
129 fn name(&self) -> &str {
130 self.protocol.name()
131 }
132
133 async fn chat(&self, request: &ChatRequest) -> Result<ChatResponse, LlmConnectorError> {
134 let protocol_request = self.protocol.build_request(request)?;
136
137 let url = self.protocol.chat_endpoint(self.client.base_url());
139
140 let response = self.client.post(&url, &protocol_request).await?;
142 let status = response.status();
143 let text = response.text().await
144 .map_err(|e| LlmConnectorError::NetworkError(e.to_string()))?;
145
146 if !status.is_success() {
148 return Err(self.protocol.map_error(status.as_u16(), &text));
149 }
150
151 self.protocol.parse_response(&text)
153 }
154
155 #[cfg(feature = "streaming")]
156 async fn chat_stream(&self, request: &ChatRequest) -> Result<ChatStream, LlmConnectorError> {
157 let mut streaming_request = request.clone();
158 streaming_request.stream = Some(true);
159
160 let protocol_request = self.protocol.build_request(&streaming_request)?;
161 let url = self.protocol.chat_endpoint(self.client.base_url());
162
163 let response = self.client.stream(&url, &protocol_request).await?;
164 let status = response.status();
165
166 if !status.is_success() {
167 let text = response.text().await
168 .map_err(|e| LlmConnectorError::NetworkError(e.to_string()))?;
169 return Err(self.protocol.map_error(status.as_u16(), &text));
170 }
171
172 self.protocol.parse_stream_response(response).await
173 }
174
175 async fn models(&self) -> Result<Vec<String>, LlmConnectorError> {
176 let endpoint = self.protocol.models_endpoint(self.client.base_url())
177 .ok_or_else(|| LlmConnectorError::UnsupportedOperation(
178 format!("{} does not support model listing", self.protocol.name())
179 ))?;
180
181 let response = self.client.get(&endpoint).await?;
182 let status = response.status();
183 let text = response.text().await
184 .map_err(|e| LlmConnectorError::NetworkError(e.to_string()))?;
185
186 if !status.is_success() {
187 return Err(self.protocol.map_error(status.as_u16(), &text));
188 }
189
190 self.protocol.parse_models(&text)
191 }
192
193 fn as_any(&self) -> &dyn Any {
194 self
195 }
196}