1use crate::{types::Content, Result};
2use async_trait::async_trait;
3use futures::stream::Stream;
4use serde::{Deserialize, Serialize};
5use std::collections::HashMap;
6use std::pin::Pin;
7
/// Boxed, pinned stream of [`LlmResponse`] results yielded by an [`Llm`]
/// backend; `Send` so the stream can be moved across task boundaries.
pub type LlmResponseStream = Pin<Box<dyn Stream<Item = Result<LlmResponse>> + Send>>;
9
/// Abstraction over a large-language-model backend.
///
/// Implementors must be `Send + Sync` so a single instance can be shared
/// across async tasks.
#[async_trait]
pub trait Llm: Send + Sync {
    /// Identifier for this backend (e.g. for logging/telemetry).
    fn name(&self) -> &str;
    /// Sends `req` to the model and returns a stream of responses.
    ///
    /// NOTE(review): `stream` presumably selects incremental (partial)
    /// responses vs. a single complete one — exact semantics depend on the
    /// implementor and are not visible from this trait alone; confirm there.
    async fn generate_content(&self, req: LlmRequest, stream: bool) -> Result<LlmResponseStream>;
}
15
/// A request to be sent to an LLM backend.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LlmRequest {
    /// Identifier of the model to invoke.
    pub model: String,
    /// Ordered prompt / conversation contents.
    pub contents: Vec<Content>,
    /// Optional generation parameters; backend defaults apply when `None`.
    pub config: Option<GenerateContentConfig>,
    /// Tool declarations keyed by name. Excluded from (de)serialization by
    /// `#[serde(skip)]`: never written out, and reconstructed as an empty
    /// map when deserializing.
    #[serde(skip)]
    pub tools: HashMap<String, serde_json::Value>,
}
24
25#[derive(Debug, Clone, Serialize, Deserialize)]
26pub struct GenerateContentConfig {
27 pub temperature: Option<f32>,
28 pub top_p: Option<f32>,
29 pub top_k: Option<i32>,
30 pub max_output_tokens: Option<i32>,
31 #[serde(skip_serializing_if = "Option::is_none")]
32 pub response_schema: Option<serde_json::Value>,
33}
34
/// A single response (or streamed fragment) from an LLM backend.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct LlmResponse {
    /// Generated content, if any was produced.
    pub content: Option<Content>,
    /// Token accounting reported by the backend, when available.
    pub usage_metadata: Option<UsageMetadata>,
    /// Why generation stopped; `None` for intermediate stream fragments.
    pub finish_reason: Option<FinishReason>,
    /// True when this is an incremental fragment of a streamed reply.
    pub partial: bool,
    /// True when the model's turn has finished.
    pub turn_complete: bool,
    /// True when generation was cut short (e.g. cancelled by the caller).
    pub interrupted: bool,
    /// Machine-readable error code, if the backend reported a failure.
    pub error_code: Option<String>,
    /// Human-readable error description accompanying `error_code`.
    pub error_message: Option<String>,
}
46
/// Token-usage accounting attached to a response.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct UsageMetadata {
    /// Tokens consumed by the prompt.
    pub prompt_token_count: i32,
    /// Tokens produced by the model's candidates.
    pub candidates_token_count: i32,
    /// Total tokens billed for the call.
    pub total_token_count: i32,
}
53
/// Reason the model stopped generating.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum FinishReason {
    /// Natural end of generation.
    Stop,
    /// Output hit the configured token limit.
    MaxTokens,
    /// Generation halted by a safety filter.
    Safety,
    /// Generation halted for reciting training data.
    Recitation,
    /// Any stop reason not covered above.
    Other,
}
62
63impl LlmRequest {
64 pub fn new(model: impl Into<String>, contents: Vec<Content>) -> Self {
65 Self { model: model.into(), contents, config: None, tools: HashMap::new() }
66 }
67}
68
69impl LlmResponse {
70 pub fn new(content: Content) -> Self {
71 Self {
72 content: Some(content),
73 usage_metadata: None,
74 finish_reason: Some(FinishReason::Stop),
75 partial: false,
76 turn_complete: true,
77 interrupted: false,
78 error_code: None,
79 error_message: None,
80 }
81 }
82}
83
#[cfg(test)]
mod tests {
    use super::*;

    // A fresh request carries the given model name and no contents.
    #[test]
    fn test_llm_request_creation() {
        let req = LlmRequest::new(String::from("test-model"), Vec::new());
        assert_eq!(req.model.as_str(), "test-model");
        assert_eq!(req.contents.len(), 0);
    }

    // `new` marks the response complete and finished with `Stop`.
    #[test]
    fn test_llm_response_creation() {
        let resp = LlmResponse::new(Content::new("assistant"));
        assert_eq!(resp.finish_reason, Some(FinishReason::Stop));
        assert!(resp.content.is_some());
        assert!(resp.turn_complete && !resp.partial);
    }

    // Variants compare equal to themselves and unequal to each other.
    #[test]
    fn test_finish_reason() {
        assert!(matches!(FinishReason::Stop, FinishReason::Stop));
        assert_ne!(FinishReason::Stop, FinishReason::MaxTokens);
    }
}