kodegen_simd/
config.rs

1//! Configuration types for SIMD-accelerated operations
2
/// Tunable parameters for the logits-processing pipeline.
///
/// Every field is public and may be set directly or through the `with_*`
/// builder methods. Nothing is range-checked on assignment; the accepted
/// ranges documented on each field are enforced by `validate`.
#[derive(Clone, Debug, PartialEq)]
pub struct ProcessorConfig {
    /// Sampling temperature. `1.0` leaves the distribution unchanged, values
    /// below `1.0` make it less random and values above `1.0` more random.
    pub temperature: f32,

    /// How many of the highest-probability tokens to keep.
    /// `None` keeps every token.
    pub top_k: Option<usize>,

    /// Nucleus (top-p) sampling threshold, expected to satisfy
    /// `0.0 < top_p <= 1.0`. `None` disables nucleus sampling.
    pub top_p: Option<f32>,

    /// Penalty for tokens that have already appeared. `1.0` applies no
    /// penalty; larger values penalize repeats more strongly.
    pub repetition_penalty: f32,

    /// Penalty that grows with how often a token has occurred. `0.0` applies
    /// no penalty; larger values penalize more strongly.
    pub frequency_penalty: f32,

    /// Penalty for tokens that are present in the context at all. `0.0`
    /// applies no penalty; larger values penalize more strongly.
    pub presence_penalty: f32,
}
24
25impl Default for ProcessorConfig {
26    fn default() -> Self {
27        Self {
28            temperature: 1.0,
29            top_k: None,
30            top_p: None,
31            repetition_penalty: 1.0,
32            frequency_penalty: 0.0,
33            presence_penalty: 0.0,
34        }
35    }
36}
37
38/// Error type for configuration validation
39#[derive(Debug, thiserror::Error)]
40pub enum ConfigError {
41    /// Invalid temperature value - must be positive
42    #[error("Invalid temperature value: {0}. Must be positive")]
43    InvalidTemperature(f32),
44
45    /// Invalid `top_k` value - must be `> 0` if set
46    #[error("Invalid top_k value: {0}. Must be > 0 if set")]
47    InvalidTopK(usize),
48
49    /// Invalid `top_p` value - must be in range `(0.0, 1.0]`
50    #[error("Invalid top_p value: {0}. Must be in range (0.0, 1.0]")]
51    InvalidTopP(f32),
52
53    /// Invalid repetition penalty - must be >= 1.0
54    #[error("Invalid repetition penalty: {0}. Must be >= 1.0")]
55    InvalidRepetitionPenalty(f32),
56
57    /// Invalid frequency penalty - must be >= 0.0
58    #[error("Invalid frequency penalty: {0}. Must be >= 0.0")]
59    InvalidFrequencyPenalty(f32),
60
61    /// Invalid presence penalty - must be >= 0.0
62    #[error("Invalid presence penalty: {0}. Must be >= 0.0")]
63    InvalidPresencePenalty(f32),
64}
65
66impl ProcessorConfig {
67    /// Create a new configuration with default values
68    #[must_use]
69    pub fn new() -> Self {
70        Self::default()
71    }
72
73    /// Validate configuration values
74    ///
75    /// # Errors
76    ///
77    /// Returns an error if any configuration parameter is invalid
78    pub fn validate(&self) -> Result<(), ConfigError> {
79        if self.temperature <= 0.0 {
80            return Err(ConfigError::InvalidTemperature(self.temperature));
81        }
82
83        if let Some(k) = self.top_k
84            && k == 0
85        {
86            return Err(ConfigError::InvalidTopK(k));
87        }
88
89        if let Some(top_p) = self.top_p
90            && !(0.0..=1.0).contains(&top_p)
91        {
92            return Err(ConfigError::InvalidTopP(top_p));
93        }
94
95        if self.repetition_penalty < 1.0 {
96            return Err(ConfigError::InvalidRepetitionPenalty(
97                self.repetition_penalty,
98            ));
99        }
100
101        if self.frequency_penalty < 0.0 {
102            return Err(ConfigError::InvalidFrequencyPenalty(self.frequency_penalty));
103        }
104
105        if self.presence_penalty < 0.0 {
106            return Err(ConfigError::InvalidPresencePenalty(self.presence_penalty));
107        }
108
109        Ok(())
110    }
111
112    /// Set temperature parameter
113    #[must_use]
114    pub const fn with_temperature(mut self, temperature: f32) -> Self {
115        self.temperature = temperature;
116        self
117    }
118
119    /// Set top-k parameter
120    #[must_use]
121    pub const fn with_top_k(mut self, top_k: Option<usize>) -> Self {
122        self.top_k = top_k;
123        self
124    }
125
126    /// Set top-p parameter
127    #[must_use]
128    pub const fn with_top_p(mut self, top_p: Option<f32>) -> Self {
129        self.top_p = top_p;
130        self
131    }
132
133    /// Set repetition penalty
134    #[must_use]
135    pub const fn with_repetition_penalty(mut self, penalty: f32) -> Self {
136        self.repetition_penalty = penalty;
137        self
138    }
139
140    /// Set frequency penalty
141    #[must_use]
142    pub const fn with_frequency_penalty(mut self, penalty: f32) -> Self {
143        self.frequency_penalty = penalty;
144        self
145    }
146
147    /// Set presence penalty
148    #[must_use]
149    pub const fn with_presence_penalty(mut self, penalty: f32) -> Self {
150        self.presence_penalty = penalty;
151        self
152    }
153}