llmvm_protocol/
lib.rs

//! This module contains protocol types and utilities
//! used for communicating with llmvm core & backends.
//!
//! Uses multilink to communicate with local & remote processes.

#[cfg(any(feature = "http-client", feature = "http-server"))]
pub mod http;
pub mod service;
#[cfg(any(feature = "stdio-client", feature = "stdio-server"))]
pub mod stdio;

pub use async_trait::async_trait;
pub use multilink::*;

use serde::{Deserialize, Serialize};
use serde_json::{Map, Value};
use std::{
    fmt::{Display, Formatter},
    str::FromStr,
};
use url::Url;

pub const CHAT_MODEL_PROVIDER_SUFFIX: &str = "-chat";
const CUSTOM_ENDPOINT_PREFIX: &str = "endpoint=";

/// Metadata for a thread.
#[derive(Clone, Serialize, Deserialize)]
pub struct ThreadInfo {
    /// The id of the thread.
    pub id: String,
    /// Last modified time of the thread.
    pub modified: String,
}

/// The actor who presented the message.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum MessageRole {
    /// For system messages, typically provided as a higher-level prompt for some models.
    System,
    /// For user messages, typically provided by the user.
    User,
    /// For assistant messages, typically generated by the model.
    Assistant,
}

/// A prompt or generated message from a thread.
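///
/// # Example
///
/// A construction sketch for a short exchange; the message contents are illustrative.
///
/// ```
/// use llmvm_protocol::{Message, MessageRole};
///
/// let messages = vec![
///     Message {
///         role: MessageRole::System,
///         content: "You are a helpful assistant.".to_string(),
///     },
///     Message {
///         role: MessageRole::User,
///         content: "Hello!".to_string(),
///     },
/// ];
/// assert_eq!(messages.len(), 2);
/// ```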
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Message {
    /// The actor who presented the message.
    pub role: MessageRole,
    /// Text content of the message.
    pub content: String,
}

/// The backend service which the core uses to generate text.
/// Implements a low-level interface for interacting with language models.
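///
/// # Example
///
/// A usage sketch, assuming a `Backend` implementation obtained elsewhere;
/// the model id, prompt and token limit below are illustrative.
///
/// ```no_run
/// use llmvm_protocol::{Backend, BackendGenerationRequest, ProtocolError};
///
/// async fn generate_once(backend: &dyn Backend) -> Result<String, ProtocolError> {
///     let response = backend
///         .generate(BackendGenerationRequest {
///             model: "outsource/openai-chat/gpt-3.5-turbo".to_string(),
///             prompt: "Say hello.".to_string(),
///             max_tokens: 64,
///             ..Default::default()
///         })
///         .await?;
///     Ok(response.response)
/// }
/// ```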
#[async_trait]
pub trait Backend: Send + Sync {
    /// Generate text and return the whole response.
    async fn generate(
        &self,
        request: BackendGenerationRequest,
    ) -> Result<BackendGenerationResponse, ProtocolError>;

    /// Request text generation and return an asynchronous stream of generated tokens.
    async fn generate_stream(
        &self,
        request: BackendGenerationRequest,
    ) -> Result<NotificationStream<BackendGenerationResponse>, ProtocolError>;
}

/// The core service which frontends use to interact with language models.
/// Manages threads, presets, prompt templates and backend connections to create and send
/// backend requests.
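///
/// # Example
///
/// A usage sketch, assuming a `Core` implementation (such as a service client)
/// obtained elsewhere; the preset id and prompt below are illustrative.
///
/// ```no_run
/// use llmvm_protocol::{Core, GenerationRequest, ProtocolError};
///
/// async fn generate_once(core: &dyn Core) -> Result<String, ProtocolError> {
///     let response = core
///         .generate(GenerationRequest {
///             preset_id: Some("my-preset".to_string()),
///             custom_prompt: Some("Say hello.".to_string()),
///             save_thread: true,
///             ..Default::default()
///         })
///         .await?;
///     Ok(response.response)
/// }
/// ```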
#[async_trait]
pub trait Core: Send + Sync {
    /// Generate text and return the whole response.
    async fn generate(
        &self,
        request: GenerationRequest,
    ) -> Result<GenerationResponse, ProtocolError>;

    /// Request text generation and return an asynchronous stream of generated tokens.
    async fn generate_stream(
        &self,
        request: GenerationRequest,
    ) -> Result<NotificationStream<GenerationResponse>, ProtocolError>;

    /// Retrieve information for the last modified thread in the current project
    /// or user data directory.
    async fn get_last_thread_info(&self) -> Result<Option<ThreadInfo>, ProtocolError>;

    /// Retrieve information for all available threads.
    async fn get_all_thread_infos(&self) -> Result<Vec<ThreadInfo>, ProtocolError>;

    /// Retrieve all thread messages for a thread id.
    async fn get_thread_messages(&self, id: String) -> Result<Vec<Message>, ProtocolError>;

    /// Initialize a new llmvm project in the current directory.
    fn init_project(&self) -> Result<(), ProtocolError>;
}

/// Request for language model generation.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct BackendGenerationRequest {
    /// The id of the language model.
    /// The format of the id is `<backend name>/<model provider name>/<model name>`.
    pub model: String,
    /// The complete prompt to present to the model.
    pub prompt: String,
    /// Maximum number of tokens to generate.
    pub max_tokens: u64,
    /// Optional thread messages from previous requests.
    pub thread_messages: Option<Vec<Message>>,
    /// Optional parameters for the model itself, e.g. `temperature`, `top_p`, etc.
    pub model_parameters: Option<Map<String, Value>>,
}

/// Response for language model generation.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BackendGenerationResponse {
    /// Generated response from the language model.
    pub response: String,
}

/// Parameters used for generation via core service.
/// Can be saved in a preset and/or directly provided within the [`GenerationRequest`].
#[derive(Debug, Default, Clone, Serialize, Deserialize)]
pub struct GenerationParameters {
    /// The id of the language model.
    /// The format of the id is `<backend name>/<model provider name>/<model name>`.
    pub model: Option<String>,
    /// An optional id for a saved prompt template to use.
    pub prompt_template_id: Option<String>,
    /// Optional text for a custom prompt template. If both this and
    /// `prompt_template_id` are defined, then `prompt_template_id` is ignored.
    pub custom_prompt_template: Option<String>,
    /// Maximum number of tokens to generate.
    pub max_tokens: Option<u64>,
    /// Optional parameters for the model itself, e.g. `temperature`, `top_p`, etc.
    pub model_parameters: Option<Map<String, Value>>,
    /// Parameters for the prompt template.
    pub prompt_parameters: Option<Value>,
}

/// Request for text generation via core service.
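///
/// # Example
///
/// A construction sketch; the model id, prompt and token limit below are illustrative.
///
/// ```
/// use llmvm_protocol::{GenerationParameters, GenerationRequest};
///
/// let request = GenerationRequest {
///     parameters: Some(GenerationParameters {
///         model: Some("outsource/openai-chat/gpt-3.5-turbo".to_string()),
///         max_tokens: Some(256),
///         ..Default::default()
///     }),
///     custom_prompt: Some("Say hello.".to_string()),
///     save_thread: true,
///     ..Default::default()
/// };
/// assert!(request.preset_id.is_none());
/// ```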
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct GenerationRequest {
    /// An optional id for a saved preset that contains generation parameters.
    pub preset_id: Option<String>,
    /// Model generation parameters. If a preset is provided, any parameter fields
    /// set here will override the preset values.
    pub parameters: Option<GenerationParameters>,
    /// A custom prompt (not a template) to use for generation.
    pub custom_prompt: Option<String>,
    /// An existing thread id for loading previous messages.
    pub existing_thread_id: Option<String>,
    /// If true, the prompt and response will be saved to the existing thread id
    /// or a new thread.
    pub save_thread: bool,
}

/// Response for text generation via core service.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GenerationResponse {
    /// The response generated by the language model.
    pub response: String,
    /// Thread id containing the prompt and newly generated response.
    /// Only provided if `save_thread` is set to true in the associated request.
    pub thread_id: Option<String>,
}

/// A parsed model id data structure.
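///
/// # Example
///
/// A parsing sketch; the model ids below are illustrative.
///
/// ```
/// use llmvm_protocol::ModelDescription;
///
/// let description: ModelDescription = "outsource/openai-chat/gpt-3.5-turbo"
///     .parse()
///     .expect("valid model id");
/// assert_eq!(description.backend, "outsource");
/// assert_eq!(description.provider, "openai-chat");
/// assert_eq!(description.model_name, "gpt-3.5-turbo");
/// assert!(description.is_chat_model());
/// assert_eq!(description.to_string(), "outsource/openai-chat/gpt-3.5-turbo");
///
/// // A custom endpoint may be appended as a final `endpoint=<url>` segment.
/// let with_endpoint: ModelDescription =
///     "outsource/openai-chat/gpt-3.5-turbo/endpoint=https://example.com/v1"
///         .parse()
///         .expect("valid model id with endpoint");
/// assert_eq!(with_endpoint.model_name, "gpt-3.5-turbo");
/// assert!(with_endpoint.endpoint.is_some());
/// ```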
pub struct ModelDescription {
    /// Name of the backend to invoke for generation, e.g. `outsource`.
    pub backend: String,
    /// Name of the provider of the model, e.g. `openai-chat`.
    pub provider: String,
    /// Name of the model, e.g. `gpt-3.5-turbo`.
    pub model_name: String,
    /// Custom endpoint (if any).
    pub endpoint: Option<Url>,
}

impl ModelDescription {
    /// Checks if the model is a "chat" model. Currently,
    /// it checks if the provider name ends with `-chat`.
    pub fn is_chat_model(&self) -> bool {
        self.provider.ends_with(CHAT_MODEL_PROVIDER_SUFFIX)
    }
}

impl FromStr for ModelDescription {
    type Err = ();

    fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
        let split = s.split("/");
        let tokens: Vec<String> = split.map(|v| v.to_string()).collect();
        // A model id must contain at least three non-empty components:
        // backend, provider and model name.
        if tokens.len() < 3 || tokens[..3].iter().any(|v| v.is_empty()) {
            return Err(());
        }
        let mut tokens_iter = tokens.into_iter();
        let backend = tokens_iter.next().unwrap();
        let provider = tokens_iter.next().unwrap();
        // The model name may contain slashes itself, so rejoin the remaining tokens.
        let mut model_name = tokens_iter.collect::<Vec<String>>().join("/");
        let mut endpoint = None;
        // Extract an optional custom endpoint appended as `endpoint=<url>`.
        if let Some(endpoint_idx) = model_name.rfind(CUSTOM_ENDPOINT_PREFIX) {
            let endpoint_str = &model_name[endpoint_idx + CUSTOM_ENDPOINT_PREFIX.len()..];
            endpoint = Some(Url::parse(endpoint_str).map_err(|_| ())?);
            model_name = model_name[..endpoint_idx].to_string();
            if model_name.ends_with("/") {
                model_name.pop();
            }
        }

        Ok(Self {
            backend,
            provider,
            model_name,
            endpoint,
        })
    }
}

impl Display for ModelDescription {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        if let Some(endpoint) = self.endpoint.as_ref() {
            // Omit the model name segment when it is empty (endpoint-only ids).
            if self.model_name.is_empty() {
                write!(
                    f,
                    "{}/{}/{}{}",
                    self.backend, self.provider, CUSTOM_ENDPOINT_PREFIX, endpoint
                )
            } else {
                write!(
                    f,
                    "{}/{}/{}/{}{}",
                    self.backend, self.provider, self.model_name, CUSTOM_ENDPOINT_PREFIX, endpoint
                )
            }
        } else {
            write!(f, "{}/{}/{}", self.backend, self.provider, self.model_name)
        }
    }
}