Skip to main content

cognee_llm/
types.rs

1//! Common types for LLM operations.
2
3use serde::{Deserialize, Serialize};
4
5/// Message role in a conversation.
6#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
7#[serde(rename_all = "lowercase")]
8pub enum MessageRole {
9    System,
10    User,
11    Assistant,
12}
13
14/// A message in a conversation.
15#[derive(Debug, Clone, Serialize, Deserialize)]
16pub struct Message {
17    pub role: MessageRole,
18    pub content: String,
19}
20
21impl Message {
22    pub fn system(content: impl Into<String>) -> Self {
23        Self {
24            role: MessageRole::System,
25            content: content.into(),
26        }
27    }
28
29    pub fn user(content: impl Into<String>) -> Self {
30        Self {
31            role: MessageRole::User,
32            content: content.into(),
33        }
34    }
35
36    pub fn assistant(content: impl Into<String>) -> Self {
37        Self {
38            role: MessageRole::Assistant,
39            content: content.into(),
40        }
41    }
42}
43
44/// Options for LLM generation.
45#[derive(Debug, Clone, Serialize, Deserialize)]
46pub struct GenerationOptions {
47    /// Temperature for sampling (0.0 = deterministic, 1.0 = creative).
48    #[serde(skip_serializing_if = "Option::is_none")]
49    pub temperature: Option<f32>,
50
51    /// Maximum number of tokens to generate.
52    #[serde(skip_serializing_if = "Option::is_none")]
53    pub max_tokens: Option<u32>,
54
55    /// Top-p sampling parameter.
56    #[serde(skip_serializing_if = "Option::is_none")]
57    pub top_p: Option<f32>,
58
59    /// Frequency penalty.
60    #[serde(skip_serializing_if = "Option::is_none")]
61    pub frequency_penalty: Option<f32>,
62
63    /// Presence penalty.
64    #[serde(skip_serializing_if = "Option::is_none")]
65    pub presence_penalty: Option<f32>,
66
67    /// Stop sequences.
68    #[serde(skip_serializing_if = "Option::is_none")]
69    pub stop: Option<Vec<String>>,
70}
71
72impl Default for GenerationOptions {
73    fn default() -> Self {
74        Self {
75            temperature: Some(0.0),
76            max_tokens: Some(16384),
77            top_p: None,
78            frequency_penalty: None,
79            presence_penalty: None,
80            stop: None,
81        }
82    }
83}
84
85/// Response from LLM generation.
86#[derive(Debug, Clone, Serialize, Deserialize)]
87pub struct GenerationResponse {
88    /// Generated text content.
89    pub content: String,
90
91    /// Model used for generation.
92    pub model: String,
93
94    /// Token usage information.
95    pub usage: Option<TokenUsage>,
96
97    /// Finish reason (e.g., "stop", "length", "content_filter").
98    pub finish_reason: Option<String>,
99}
100
101/// Token usage statistics.
102#[derive(Debug, Clone, Serialize, Deserialize)]
103pub struct TokenUsage {
104    pub prompt_tokens: u32,
105    pub completion_tokens: u32,
106    pub total_tokens: u32,
107}