// Source file: rig_cat/model/mod.rs

//! Completion model trait: the core LLM abstraction.
//!
//! A `CompletionModel` takes a prompt and returns a response,
//! wrapped in `Io` for composable effect handling.
5
6use comp_cat_rs::effect::io::Io;
7use comp_cat_rs::effect::stream::Stream;
8use serde::{Deserialize, Serialize};
9
10use crate::error::Error;
11
12/// A message in a conversation.
13#[derive(Debug, Clone, Serialize, Deserialize)]
14pub struct Message {
15    role: Role,
16    content: String,
17}
18
19impl Message {
20    #[must_use]
21    pub fn new(role: Role, content: String) -> Self {
22        Self { role, content }
23    }
24
25    #[must_use]
26    pub fn role(&self) -> &Role { &self.role }
27
28    #[must_use]
29    pub fn content(&self) -> &str { &self.content }
30
31    #[must_use]
32    pub fn system(content: String) -> Self {
33        Self::new(Role::System, content)
34    }
35
36    #[must_use]
37    pub fn user(content: String) -> Self {
38        Self::new(Role::User, content)
39    }
40
41    #[must_use]
42    pub fn assistant(content: String) -> Self {
43        Self::new(Role::Assistant, content)
44    }
45}
46
47/// The role of a message sender.
48#[derive(Debug, Clone, Serialize, Deserialize)]
49#[serde(rename_all = "lowercase")]
50pub enum Role {
51    System,
52    User,
53    Assistant,
54}
55
56/// A completion request.
57#[derive(Debug, Clone)]
58pub struct CompletionRequest {
59    messages: Vec<Message>,
60    temperature: Option<f64>,
61    max_tokens: Option<u32>,
62}
63
64impl CompletionRequest {
65    #[must_use]
66    pub fn new(messages: Vec<Message>) -> Self {
67        Self {
68            messages,
69            temperature: None,
70            max_tokens: None,
71        }
72    }
73
74    #[must_use]
75    pub fn messages(&self) -> &[Message] { &self.messages }
76
77    #[must_use]
78    pub fn temperature(&self) -> Option<f64> { self.temperature }
79
80    #[must_use]
81    pub fn max_tokens(&self) -> Option<u32> { self.max_tokens }
82
83    #[must_use]
84    pub fn with_temperature(self, t: f64) -> Self {
85        Self { temperature: Some(t), ..self }
86    }
87
88    #[must_use]
89    pub fn with_max_tokens(self, n: u32) -> Self {
90        Self { max_tokens: Some(n), ..self }
91    }
92}
93
/// A completion response.
///
/// Carries two strings: the response `content` and a `model` string
/// (presumably the identifier of the model that produced the response;
/// confirm against the provider implementations).
///
/// `PartialEq` and `Eq` are derived so responses can be compared directly,
/// e.g. in tests.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CompletionResponse {
    content: String,
    model: String,
}

impl CompletionResponse {
    /// Creates a response from its `content` and `model` strings.
    #[must_use]
    pub fn new(content: String, model: String) -> Self {
        Self { content, model }
    }

    /// Returns the response content.
    #[must_use]
    pub fn content(&self) -> &str {
        &self.content
    }

    /// Returns the `model` string supplied at construction.
    #[must_use]
    pub fn model(&self) -> &str {
        &self.model
    }
}
113
/// A token chunk from a streaming response.
///
/// Wraps a single `delta` string. `PartialEq` and `Eq` are derived so
/// chunks can be compared directly, e.g. in tests.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct StreamChunk {
    delta: String,
}

impl StreamChunk {
    /// Creates a chunk holding `delta`.
    #[must_use]
    pub fn new(delta: String) -> Self {
        Self { delta }
    }

    /// Returns the text carried by this chunk.
    #[must_use]
    pub fn delta(&self) -> &str {
        &self.delta
    }
}
127
128/// The core LLM abstraction: send a request, get a response.
129///
130/// All provider-specific logic is behind this trait.
131/// Returns `Io<Error, _>` for composable effect handling.
132pub trait CompletionModel {
133    /// Send a completion request and get a full response.
134    fn complete(&self, request: CompletionRequest) -> Io<Error, CompletionResponse>;
135
136    /// Send a completion request and get a streaming response.
137    fn stream(&self, request: CompletionRequest) -> Stream<Error, StreamChunk>;
138}
139
#[cfg(test)]
mod tests {
    use super::*;

    /// Each convenience constructor must tag the message with its own role
    /// and keep the content untouched.
    #[test]
    fn message_constructors_set_correct_roles() {
        let sys = Message::system("sys".into());
        let usr = Message::user("usr".into());
        let ast = Message::assistant("ast".into());
        assert!(matches!(sys.role(), Role::System));
        assert!(matches!(usr.role(), Role::User));
        assert!(matches!(ast.role(), Role::Assistant));
        assert_eq!(sys.content(), "sys");
        assert_eq!(usr.content(), "usr");
        assert_eq!(ast.content(), "ast");
    }

    /// The `with_*` builders must store the values they are given.
    #[test]
    fn completion_request_builder_applies_options() {
        let req = CompletionRequest::new(vec![Message::user("hi".into())])
            .with_temperature(0.5)
            .with_max_tokens(100);
        assert_eq!(req.messages().len(), 1);
        // Tolerance-based float check; a missing temperature falls back to
        // 0.0, which is far from 0.5 and therefore still fails the assert.
        assert!((req.temperature().unwrap_or(0.0) - 0.5).abs() < 1e-10);
        assert_eq!(req.max_tokens(), Some(100));
    }

    /// A freshly constructed request has no optional settings.
    #[test]
    fn completion_request_defaults_are_none() {
        let req = CompletionRequest::new(vec![]);
        assert!(req.temperature().is_none());
        assert!(req.max_tokens().is_none());
    }
}
173}