1#[cfg(any(feature = "http-client", feature = "http-server"))]
7pub mod http;
8pub mod service;
9#[cfg(any(feature = "stdio-client", feature = "stdio-server"))]
10pub mod stdio;
11
12pub use async_trait::async_trait;
13pub use multilink::*;
14
15use serde::{Deserialize, Serialize};
16use serde_json::{Map, Value};
17use std::{
18 fmt::{Display, Formatter},
19 str::FromStr,
20};
21use url::Url;
22
/// Marker that introduces a custom endpoint URL inside a model description string,
/// e.g. `backend/provider/model/endpoint=http://localhost:8080/`.
const CUSTOM_ENDPOINT_PREFIX: &str = "endpoint=";
/// Provider-name suffix that identifies chat models (checked by `ModelDescription::is_chat_model`).
pub const CHAT_MODEL_PROVIDER_SUFFIX: &str = "-chat";
25
26#[derive(Clone, Serialize, Deserialize)]
28pub struct ThreadInfo {
29 pub id: String,
31 pub modified: String,
33}
34
35#[derive(Debug, Clone, Serialize, Deserialize)]
37#[serde(rename_all = "lowercase")]
38pub enum MessageRole {
39 System,
41 User,
43 Assistant,
45}
46
47#[derive(Debug, Clone, Serialize, Deserialize)]
49pub struct Message {
50 pub role: MessageRole,
52 pub content: String,
54}
55
56#[async_trait]
59pub trait Backend: Send + Sync {
60 async fn generate(
62 &self,
63 request: BackendGenerationRequest,
64 ) -> Result<BackendGenerationResponse, ProtocolError>;
65
66 async fn generate_stream(
68 &self,
69 request: BackendGenerationRequest,
70 ) -> Result<NotificationStream<BackendGenerationResponse>, ProtocolError>;
71}
72
73#[async_trait]
77pub trait Core: Send + Sync {
78 async fn generate(
80 &self,
81 request: GenerationRequest,
82 ) -> Result<GenerationResponse, ProtocolError>;
83
84 async fn generate_stream(
86 &self,
87 request: GenerationRequest,
88 ) -> Result<NotificationStream<GenerationResponse>, ProtocolError>;
89
90 async fn get_last_thread_info(&self) -> Result<Option<ThreadInfo>, ProtocolError>;
93
94 async fn get_all_thread_infos(&self) -> Result<Vec<ThreadInfo>, ProtocolError>;
96
97 async fn get_thread_messages(&self, id: String) -> Result<Vec<Message>, ProtocolError>;
99
100 fn init_project(&self) -> Result<(), ProtocolError>;
102}
103
104#[derive(Debug, Clone, Default, Serialize, Deserialize)]
106pub struct BackendGenerationRequest {
107 pub model: String,
110 pub prompt: String,
112 pub max_tokens: u64,
114 pub thread_messages: Option<Vec<Message>>,
116 pub model_parameters: Option<Map<String, Value>>,
118}
119
120#[derive(Debug, Clone, Serialize, Deserialize)]
122pub struct BackendGenerationResponse {
123 pub response: String,
125}
126
127#[derive(Debug, Default, Clone, Serialize, Deserialize)]
130pub struct GenerationParameters {
131 pub model: Option<String>,
134 pub prompt_template_id: Option<String>,
136 pub custom_prompt_template: Option<String>,
139 pub max_tokens: Option<u64>,
141 pub model_parameters: Option<Map<String, Value>>,
143 pub prompt_parameters: Option<Value>,
145}
146
147#[derive(Debug, Clone, Default, Serialize, Deserialize)]
149pub struct GenerationRequest {
150 pub preset_id: Option<String>,
152 pub parameters: Option<GenerationParameters>,
155 pub custom_prompt: Option<String>,
157 pub existing_thread_id: Option<String>,
159 pub save_thread: bool,
162}
163
164#[derive(Debug, Clone, Serialize, Deserialize)]
166pub struct GenerationResponse {
167 pub response: String,
169 pub thread_id: Option<String>,
172}
173
174pub struct ModelDescription {
176 pub backend: String,
178 pub provider: String,
180 pub model_name: String,
182 pub endpoint: Option<Url>,
184}
185
186impl ModelDescription {
187 pub fn is_chat_model(&self) -> bool {
190 self.provider.ends_with(CHAT_MODEL_PROVIDER_SUFFIX)
191 }
192}
193
194impl FromStr for ModelDescription {
195 type Err = ();
196
197 fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
198 let split = s.split("/");
199 let tokens: Vec<String> = split.map(|v| v.to_string()).collect();
200 if tokens.len() < 3 || tokens[..3].iter().any(|v| v.is_empty()) {
201 return Err(());
202 }
203 let mut tokens_iter = tokens.into_iter();
204 let backend = tokens_iter.next().unwrap();
205 let provider = tokens_iter.next().unwrap();
206 let mut model_name = tokens_iter.collect::<Vec<String>>().join("/");
207 let mut endpoint = None;
208 if let Some(endpoint_idx) = model_name.rfind(CUSTOM_ENDPOINT_PREFIX) {
209 let endpoint_str = &model_name[endpoint_idx + CUSTOM_ENDPOINT_PREFIX.len()..];
210 endpoint = Some(Url::parse(endpoint_str).map_err(|_| ())?);
211 model_name = model_name[..endpoint_idx].to_string();
212 if model_name.ends_with("/") {
213 model_name.pop();
214 }
215 }
216
217 Ok(Self {
218 backend,
219 provider,
220 model_name,
221 endpoint,
222 })
223 }
224}
225
226impl Display for ModelDescription {
227 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
228 if let Some(endpoint) = self.endpoint.as_ref() {
229 if self.model_name.is_empty() {
230 write!(
231 f,
232 "{}/{}/{}{}",
233 self.backend, self.provider, CUSTOM_ENDPOINT_PREFIX, endpoint
234 )
235 } else {
236 write!(
237 f,
238 "{}/{}/{}/{}{}",
239 self.backend, self.provider, self.model_name, CUSTOM_ENDPOINT_PREFIX, endpoint
240 )
241 }
242 } else {
243 write!(f, "{}/{}/{}", self.backend, self.provider, self.model_name)
244 }
245 }
246}