1use serde::{Deserialize, Serialize};
6use tracing::warn;
7
8#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
14#[serde(rename_all = "lowercase")]
15pub enum Role {
16 System,
18 User,
20 Assistant,
22 Tool,
24}
25
26impl std::fmt::Display for Role {
27 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
28 match self {
29 Role::System => write!(f, "system"),
30 Role::User => write!(f, "user"),
31 Role::Assistant => write!(f, "assistant"),
32 Role::Tool => write!(f, "tool"),
33 }
34 }
35}
36
37#[derive(Debug, Clone, Serialize, Deserialize)]
39pub struct ChatMessage {
40 pub role: Role,
42 pub content: String,
44}
45
46impl ChatMessage {
47 pub fn system(content: impl Into<String>) -> Self {
49 Self {
50 role: Role::System,
51 content: content.into(),
52 }
53 }
54
55 pub fn user(content: impl Into<String>) -> Self {
57 Self {
58 role: Role::User,
59 content: content.into(),
60 }
61 }
62
63 pub fn assistant(content: impl Into<String>) -> Self {
65 Self {
66 role: Role::Assistant,
67 content: content.into(),
68 }
69 }
70}
71
72#[derive(Debug, Clone)]
81pub struct ChatRequest {
82 pub messages: Vec<ChatMessage>,
84 pub model: Option<String>,
86 pub temperature: Option<f32>,
88 pub max_tokens: Option<u32>,
90 pub top_p: Option<f32>,
92 pub stop: Vec<String>,
94}
95
96impl ChatRequest {
97 pub fn builder() -> ChatRequestBuilder {
99 ChatRequestBuilder::default()
100 }
101
102 pub fn prompt(content: impl Into<String>) -> Self {
104 Self {
105 messages: vec![ChatMessage::user(content)],
106 model: None,
107 temperature: None,
108 max_tokens: None,
109 top_p: None,
110 stop: Vec::new(),
111 }
112 }
113
114 pub fn with_system(system: impl Into<String>, prompt: impl Into<String>) -> Self {
116 Self {
117 messages: vec![ChatMessage::system(system), ChatMessage::user(prompt)],
118 model: None,
119 temperature: None,
120 max_tokens: None,
121 top_p: None,
122 stop: Vec::new(),
123 }
124 }
125}
126
127#[derive(Debug, Default)]
129pub struct ChatRequestBuilder {
130 messages: Vec<ChatMessage>,
131 model: Option<String>,
132 temperature: Option<f32>,
133 max_tokens: Option<u32>,
134 top_p: Option<f32>,
135 stop: Vec<String>,
136}
137
138impl ChatRequestBuilder {
139 pub fn message(mut self, msg: ChatMessage) -> Self {
141 self.messages.push(msg);
142 self
143 }
144
145 pub fn messages(mut self, msgs: impl IntoIterator<Item = ChatMessage>) -> Self {
147 self.messages.extend(msgs);
148 self
149 }
150
151 pub fn system(self, content: impl Into<String>) -> Self {
153 self.message(ChatMessage::system(content))
154 }
155
156 pub fn user(self, content: impl Into<String>) -> Self {
158 self.message(ChatMessage::user(content))
159 }
160
161 pub fn model(mut self, model: impl Into<String>) -> Self {
163 self.model = Some(model.into());
164 self
165 }
166
167 pub fn temperature(mut self, t: f32) -> Self {
169 let clamped = t.clamp(0.0, 2.0);
170 if (clamped - t).abs() > f32::EPSILON {
171 warn!(
172 requested = t,
173 clamped = clamped,
174 "temperature value {t} out of range [0.0, 2.0], clamped to {clamped}"
175 );
176 }
177 self.temperature = Some(clamped);
178 self
179 }
180
181 pub fn max_tokens(mut self, n: u32) -> Self {
183 self.max_tokens = Some(n);
184 self
185 }
186
187 pub fn top_p(mut self, p: f32) -> Self {
189 let clamped = p.clamp(0.0, 1.0);
190 if (clamped - p).abs() > f32::EPSILON {
191 warn!(
192 requested = p,
193 clamped = clamped,
194 "top_p value {p} out of range [0.0, 1.0], clamped to {clamped}"
195 );
196 }
197 self.top_p = Some(clamped);
198 self
199 }
200
201 pub fn stop(mut self, s: impl Into<String>) -> Self {
203 self.stop.push(s.into());
204 self
205 }
206
207 pub fn build(self) -> ChatRequest {
209 ChatRequest {
210 messages: self.messages,
211 model: self.model,
212 temperature: self.temperature,
213 max_tokens: self.max_tokens,
214 top_p: self.top_p,
215 stop: self.stop,
216 }
217 }
218}
219
220#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
222#[serde(rename_all = "snake_case")]
223pub enum FinishReason {
224 Stop,
226 Length,
228 ContentFilter,
230 ToolCalls,
232}
233
234impl std::fmt::Display for FinishReason {
235 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
236 match self {
237 FinishReason::Stop => write!(f, "stop"),
238 FinishReason::Length => write!(f, "length"),
239 FinishReason::ContentFilter => write!(f, "content_filter"),
240 FinishReason::ToolCalls => write!(f, "tool_calls"),
241 }
242 }
243}
244
245#[derive(Debug, Clone)]
247pub struct ChatChoice {
248 pub index: u32,
250 pub message: ChatMessage,
252 pub finish_reason: FinishReason,
254}
255
/// Token counts attached to a chat response.
///
/// `Default` yields all-zero counts. Nothing here enforces that
/// `total_tokens` equals the sum of the other two fields.
#[derive(Debug, Clone, Copy, Default)]
pub struct Usage {
    /// Prompt token count.
    pub prompt_tokens: u32,
    /// Completion token count.
    pub completion_tokens: u32,
    /// Total token count.
    pub total_tokens: u32,
}
266
267#[derive(Debug, Clone)]
269pub struct ChatResponse {
270 pub id: String,
272 pub model: String,
274 pub choices: Vec<ChatChoice>,
276 pub usage: Usage,
278}
279
280impl ChatResponse {
281 pub fn content(&self) -> &str {
283 self.choices
284 .first()
285 .map(|c| c.message.content.as_str())
286 .unwrap_or("")
287 }
288
289 pub fn finish_reason(&self) -> Option<FinishReason> {
291 self.choices.first().map(|c| c.finish_reason)
292 }
293}
294
/// An embedding request: the input texts plus optional settings.
///
/// `None` means the caller did not specify the setting.
#[derive(Debug, Clone)]
pub struct EmbeddingRequest {
    /// Texts to embed, in the order they were supplied.
    pub input: Vec<String>,
    /// Model identifier, if one was chosen.
    pub model: Option<String>,
    /// Requested dimension count, if set.
    /// NOTE(review): presumably the desired output vector length — confirm with the provider layer.
    pub dimensions: Option<u32>,
    /// Free-form input-type tag, if set (the tests use `"search_query"`).
    /// NOTE(review): presumably a provider-specific hint — confirm which providers read it.
    pub input_type: Option<String>,
}
313
314impl EmbeddingRequest {
315 pub fn builder() -> EmbeddingRequestBuilder {
317 EmbeddingRequestBuilder::default()
318 }
319
320 pub fn single(text: impl Into<String>) -> Self {
322 Self {
323 input: vec![text.into()],
324 model: None,
325 dimensions: None,
326 input_type: None,
327 }
328 }
329
330 pub fn batch(texts: impl IntoIterator<Item = impl Into<String>>) -> Self {
332 Self {
333 input: texts.into_iter().map(Into::into).collect(),
334 model: None,
335 dimensions: None,
336 input_type: None,
337 }
338 }
339}
340
/// Step-by-step builder for [`EmbeddingRequest`].
///
/// Starts empty (via `Default`); each setter consumes and returns the
/// builder, and `build` produces the final request.
#[derive(Debug, Default)]
#[must_use = "a builder does nothing until `build` is called"]
pub struct EmbeddingRequestBuilder {
    /// Texts accumulated so far, in insertion order.
    input: Vec<String>,
    /// Model identifier, if set.
    model: Option<String>,
    /// Requested dimension count, if set.
    dimensions: Option<u32>,
    /// Input-type tag, if set.
    input_type: Option<String>,
}
349
350impl EmbeddingRequestBuilder {
351 pub fn input(mut self, text: impl Into<String>) -> Self {
353 self.input.push(text.into());
354 self
355 }
356
357 pub fn inputs(mut self, texts: impl IntoIterator<Item = impl Into<String>>) -> Self {
359 self.input.extend(texts.into_iter().map(Into::into));
360 self
361 }
362
363 pub fn model(mut self, model: impl Into<String>) -> Self {
365 self.model = Some(model.into());
366 self
367 }
368
369 pub fn dimensions(mut self, d: u32) -> Self {
371 self.dimensions = Some(d);
372 self
373 }
374
375 pub fn input_type(mut self, t: impl Into<String>) -> Self {
377 self.input_type = Some(t.into());
378 self
379 }
380
381 pub fn build(self) -> EmbeddingRequest {
383 EmbeddingRequest {
384 input: self.input,
385 model: self.model,
386 dimensions: self.dimensions,
387 input_type: self.input_type,
388 }
389 }
390}
391
/// A single embedding vector together with its index.
#[derive(Debug, Clone)]
pub struct Embedding {
    /// Index of this embedding.
    /// NOTE(review): presumably the position of the matching input text — confirm.
    pub index: u32,
    /// Components of the vector.
    pub values: Vec<f32>,
}

impl Embedding {
    /// Returns the number of components in `values` (0 for an empty vector).
    pub fn dimensions(&self) -> usize {
        self.values.len()
    }
}
407
/// Token counts attached to an embedding response.
///
/// `Default` yields all-zero counts.
#[derive(Debug, Clone, Copy, Default)]
pub struct EmbeddingUsage {
    /// Prompt token count.
    pub prompt_tokens: u32,
    /// Total token count.
    pub total_tokens: u32,
}
416
417#[derive(Debug, Clone)]
419pub struct EmbeddingResponse {
420 pub model: String,
422 pub embeddings: Vec<Embedding>,
424 pub usage: EmbeddingUsage,
426}
427
428impl EmbeddingResponse {
429 pub fn first_embedding(&self) -> Option<&[f32]> {
431 self.embeddings.first().map(|e| e.values.as_slice())
432 }
433
434 pub fn len(&self) -> usize {
436 self.embeddings.len()
437 }
438
439 pub fn is_empty(&self) -> bool {
441 self.embeddings.is_empty()
442 }
443}
444
/// Unit tests for the chat and embedding request/response types.
#[cfg(test)]
mod tests {
    use super::*;

    // ---- ChatMessage / ChatRequest -------------------------------------

    #[test]
    fn test_chat_message_constructors() {
        let sys = ChatMessage::system("You are helpful.");
        assert_eq!(sys.role, Role::System);
        assert_eq!(sys.content, "You are helpful.");

        let usr = ChatMessage::user("Hello");
        assert_eq!(usr.role, Role::User);
        assert_eq!(usr.content, "Hello");

        let ast = ChatMessage::assistant("Hi there!");
        assert_eq!(ast.role, Role::Assistant);
        assert_eq!(ast.content, "Hi there!");
    }

    #[test]
    fn test_chat_request_builder() {
        let req = ChatRequest::builder()
            .system("Be concise.")
            .user("What is Rust?")
            .temperature(0.5)
            .max_tokens(100)
            .model("gpt-4o")
            .build();

        assert_eq!(req.messages.len(), 2);
        assert_eq!(req.messages[0].role, Role::System);
        assert_eq!(req.messages[1].role, Role::User);
        // 0.5 is inside the allowed range, so it is stored unchanged.
        assert_eq!(req.temperature, Some(0.5));
        assert_eq!(req.max_tokens, Some(100));
        assert_eq!(req.model.as_deref(), Some("gpt-4o"));
    }

    #[test]
    fn test_chat_request_prompt() {
        let req = ChatRequest::prompt("Hello");
        assert_eq!(req.messages.len(), 1);
        assert_eq!(req.messages[0].role, Role::User);
        assert_eq!(req.messages[0].content, "Hello");
    }

    #[test]
    fn test_chat_request_with_system() {
        let req = ChatRequest::with_system("Be brief.", "Hi");
        assert_eq!(req.messages.len(), 2);
        assert_eq!(req.messages[0].role, Role::System);
    }

    #[test]
    fn test_temperature_clamping() {
        // Above the upper bound.
        let req = ChatRequest::builder().temperature(5.0).build();
        assert_eq!(req.temperature, Some(2.0));

        // Below the lower bound.
        let req = ChatRequest::builder().temperature(-1.0).build();
        assert_eq!(req.temperature, Some(0.0));
    }

    #[test]
    fn test_top_p_clamping() {
        // Above the upper bound.
        let req = ChatRequest::builder().top_p(1.5).build();
        assert_eq!(req.top_p, Some(1.0));

        // Below the lower bound.
        let req = ChatRequest::builder().top_p(-0.5).build();
        assert_eq!(req.top_p, Some(0.0));
    }

    // ---- ChatResponse --------------------------------------------------

    #[test]
    fn test_chat_response_content() {
        let resp = ChatResponse {
            id: "test".to_string(),
            model: "gpt-4o".to_string(),
            choices: vec![ChatChoice {
                index: 0,
                message: ChatMessage::assistant("Hello!"),
                finish_reason: FinishReason::Stop,
            }],
            usage: Usage {
                prompt_tokens: 5,
                completion_tokens: 1,
                total_tokens: 6,
            },
        };
        assert_eq!(resp.content(), "Hello!");
        assert_eq!(resp.finish_reason(), Some(FinishReason::Stop));
    }

    #[test]
    fn test_chat_response_empty() {
        let resp = ChatResponse {
            id: "test".to_string(),
            model: "test".to_string(),
            choices: vec![],
            usage: Usage::default(),
        };
        // With no choices the accessors fall back to "" / None.
        assert_eq!(resp.content(), "");
        assert_eq!(resp.finish_reason(), None);
    }

    // ---- EmbeddingRequest ----------------------------------------------

    #[test]
    fn test_embedding_request_single() {
        let req = EmbeddingRequest::single("Hello world");
        assert_eq!(req.input.len(), 1);
        assert_eq!(req.input[0], "Hello world");
    }

    #[test]
    fn test_embedding_request_batch() {
        let req = EmbeddingRequest::batch(["one", "two", "three"]);
        assert_eq!(req.input.len(), 3);
    }

    #[test]
    fn test_embedding_request_builder() {
        let req = EmbeddingRequest::builder()
            .input("hello")
            .input("world")
            .model("text-embedding-3-small")
            .dimensions(256)
            .build();
        assert_eq!(req.input.len(), 2);
        assert_eq!(req.model.as_deref(), Some("text-embedding-3-small"));
        assert_eq!(req.dimensions, Some(256));
        assert!(req.input_type.is_none());
    }

    #[test]
    fn test_embedding_request_builder_with_input_type() {
        let req = EmbeddingRequest::builder()
            .input("query")
            .model("cohere.embed-english-v3")
            .input_type("search_query")
            .build();
        assert_eq!(req.input_type.as_deref(), Some("search_query"));
    }

    // ---- EmbeddingResponse ---------------------------------------------

    #[test]
    fn test_embedding_response_first() {
        let resp = EmbeddingResponse {
            model: "test".to_string(),
            embeddings: vec![Embedding {
                index: 0,
                values: vec![0.1, 0.2, 0.3],
            }],
            usage: EmbeddingUsage {
                prompt_tokens: 2,
                total_tokens: 2,
            },
        };
        assert_eq!(resp.first_embedding(), Some([0.1, 0.2, 0.3].as_slice()));
        assert_eq!(resp.len(), 1);
        assert!(!resp.is_empty());
        assert_eq!(resp.embeddings[0].dimensions(), 3);
    }

    #[test]
    fn test_embedding_response_empty() {
        let resp = EmbeddingResponse {
            model: "test".to_string(),
            embeddings: vec![],
            usage: EmbeddingUsage::default(),
        };
        assert!(resp.is_empty());
        assert_eq!(resp.first_embedding(), None);
    }

    // ---- Display / serde -----------------------------------------------

    #[test]
    fn test_role_display() {
        assert_eq!(Role::System.to_string(), "system");
        assert_eq!(Role::User.to_string(), "user");
        assert_eq!(Role::Assistant.to_string(), "assistant");
        assert_eq!(Role::Tool.to_string(), "tool");
    }

    #[test]
    fn test_finish_reason_display() {
        assert_eq!(FinishReason::Stop.to_string(), "stop");
        assert_eq!(FinishReason::Length.to_string(), "length");
        assert_eq!(FinishReason::ContentFilter.to_string(), "content_filter");
        assert_eq!(FinishReason::ToolCalls.to_string(), "tool_calls");
    }

    #[test]
    fn test_role_serde_roundtrip() {
        let json = serde_json::to_string(&Role::System).unwrap();
        assert_eq!(json, r#""system""#);
        let back: Role = serde_json::from_str(&json).unwrap();
        assert_eq!(back, Role::System);
    }

    #[test]
    fn test_finish_reason_serde_roundtrip() {
        let json = serde_json::to_string(&FinishReason::ContentFilter).unwrap();
        assert_eq!(json, r#""content_filter""#);
        let back: FinishReason = serde_json::from_str(&json).unwrap();
        assert_eq!(back, FinishReason::ContentFilter);
    }

    #[test]
    fn test_usage_default() {
        let u = Usage::default();
        assert_eq!(u.prompt_tokens, 0);
        assert_eq!(u.completion_tokens, 0);
        assert_eq!(u.total_tokens, 0);
    }
}