1use comp_cat_rs::effect::io::Io;
7use comp_cat_rs::effect::stream::Stream;
8use serde::{Deserialize, Serialize};
9
10use crate::error::Error;
11
12#[derive(Debug, Clone, Serialize, Deserialize)]
14pub struct Message {
15 role: Role,
16 content: String,
17}
18
19impl Message {
20 #[must_use]
21 pub fn new(role: Role, content: String) -> Self {
22 Self { role, content }
23 }
24
25 #[must_use]
26 pub fn role(&self) -> &Role { &self.role }
27
28 #[must_use]
29 pub fn content(&self) -> &str { &self.content }
30
31 #[must_use]
32 pub fn system(content: String) -> Self {
33 Self::new(Role::System, content)
34 }
35
36 #[must_use]
37 pub fn user(content: String) -> Self {
38 Self::new(Role::User, content)
39 }
40
41 #[must_use]
42 pub fn assistant(content: String) -> Self {
43 Self::new(Role::Assistant, content)
44 }
45}
46
47#[derive(Debug, Clone, Serialize, Deserialize)]
49#[serde(rename_all = "lowercase")]
50pub enum Role {
51 System,
52 User,
53 Assistant,
54}
55
56#[derive(Debug, Clone)]
58pub struct CompletionRequest {
59 messages: Vec<Message>,
60 temperature: Option<f64>,
61 max_tokens: Option<u32>,
62}
63
64impl CompletionRequest {
65 #[must_use]
66 pub fn new(messages: Vec<Message>) -> Self {
67 Self {
68 messages,
69 temperature: None,
70 max_tokens: None,
71 }
72 }
73
74 #[must_use]
75 pub fn messages(&self) -> &[Message] { &self.messages }
76
77 #[must_use]
78 pub fn temperature(&self) -> Option<f64> { self.temperature }
79
80 #[must_use]
81 pub fn max_tokens(&self) -> Option<u32> { self.max_tokens }
82
83 #[must_use]
84 pub fn with_temperature(self, t: f64) -> Self {
85 Self { temperature: Some(t), ..self }
86 }
87
88 #[must_use]
89 pub fn with_max_tokens(self, n: u32) -> Self {
90 Self { max_tokens: Some(n), ..self }
91 }
92}
93
/// Result of a non-streaming completion: the generated text and the model
/// name reported alongside it.
///
/// Both fields are plain `String`s, so the type derives `PartialEq`/`Eq`.
/// This lets responses be compared directly, e.g. in tests or caches.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CompletionResponse {
    /// Generated text.
    content: String,
    /// Name of the model recorded for this response.
    model: String,
}
100
101impl CompletionResponse {
102 #[must_use]
103 pub fn new(content: String, model: String) -> Self {
104 Self { content, model }
105 }
106
107 #[must_use]
108 pub fn content(&self) -> &str { &self.content }
109
110 #[must_use]
111 pub fn model(&self) -> &str { &self.model }
112}
113
/// One piece of a streamed completion, as yielded by [`CompletionModel::stream`].
///
/// NOTE(review): `delta` is presumably the text appended since the previous
/// chunk; confirm against the implementors.
///
/// The only field is a `String`, so the type derives `PartialEq`/`Eq`.
/// This lets chunks be compared directly.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct StreamChunk {
    /// Text carried by this chunk.
    delta: String,
}
119
120impl StreamChunk {
121 #[must_use]
122 pub fn new(delta: String) -> Self { Self { delta } }
123
124 #[must_use]
125 pub fn delta(&self) -> &str { &self.delta }
126}
127
128pub trait CompletionModel {
133 fn complete(&self, request: CompletionRequest) -> Io<Error, CompletionResponse>;
135
136 fn stream(&self, request: CompletionRequest) -> Stream<Error, StreamChunk>;
138}
139
#[cfg(test)]
mod tests {
    use super::*;

    /// Each role-specific constructor tags the message with its role and
    /// keeps the given text verbatim.
    #[test]
    fn message_constructors_set_correct_roles() {
        let system = Message::system(String::from("sys"));
        assert!(matches!(system.role(), Role::System));
        assert_eq!(system.content(), "sys");

        let user = Message::user(String::from("usr"));
        assert!(matches!(user.role(), Role::User));
        assert_eq!(user.content(), "usr");

        let assistant = Message::assistant(String::from("ast"));
        assert!(matches!(assistant.role(), Role::Assistant));
        assert_eq!(assistant.content(), "ast");
    }

    /// Builder methods set their option and leave the messages intact.
    #[test]
    fn completion_request_builder_applies_options() {
        let request = CompletionRequest::new(vec![Message::user(String::from("hi"))])
            .with_max_tokens(100)
            .with_temperature(0.5);

        assert_eq!(request.messages().len(), 1);
        let temperature = request.temperature().unwrap_or(0.0);
        assert!((temperature - 0.5).abs() < 1e-10);
        assert_eq!(request.max_tokens(), Some(100));
    }

    /// A freshly constructed request has no sampling options set.
    #[test]
    fn completion_request_defaults_are_none() {
        let request = CompletionRequest::new(Vec::new());
        assert_eq!(request.temperature(), None);
        assert_eq!(request.max_tokens(), None);
    }
}