1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
//! Token counting functionality for the Ollama API client
//!
//! This module provides comprehensive token counting capabilities including:
//! - Token estimation for text inputs before API calls
//! - Cost calculation based on token counts
//! - Input validation using token limits
//! - Batch operation optimization based on token counts
//!
//! All functionality follows the "Thin Client, Rich API" governing principle,
//! providing explicit control with transparent API mapping to Ollama endpoints.
use serde::{ Serialize, Deserialize };
/// Request structure for counting tokens in text input
///
/// Serialized to JSON and sent to the token-counting endpoint; `options` is an
/// opaque pass-through for tokenizer-specific settings.
#[ derive( Debug, Clone, Serialize, Deserialize ) ]
pub struct TokenCountRequest
{
/// Model name to use for token counting (e.g., "llama3.2")
pub model : String,
/// Text to count tokens for
pub text : String,
/// Additional tokenization options, forwarded verbatim to the API.
/// NOTE(review): serializes as `null` when `None` — confirm the endpoint
/// accepts a null `options` field, otherwise a serde skip attribute is needed.
pub options : Option< serde_json::Value >,
}
/// Response structure for token counting results
///
/// One result per counted text; the optional fields are populated only when
/// the backend supplies them.
#[ derive( Debug, Clone, Serialize, Deserialize ) ]
pub struct TokenCountResponse
{
/// Number of tokens in the input text
pub token_count : u32,
/// Model used for token counting
pub model : String,
/// Length of input text in characters
/// NOTE(review): producer-defined — confirm whether this is Unicode
/// characters or UTF-8 bytes.
pub text_length : usize,
/// Estimated cost for processing this many tokens
/// NOTE(review): currency is not carried here — presumably matches the
/// `CostEstimation` currency ("USD" by default); verify against the producer.
pub estimated_cost : Option< f64 >,
/// Time taken to count tokens in milliseconds
pub processing_time_ms : Option< u64 >,
/// Additional metadata from token counting
pub metadata : Option< serde_json::Value >,
}
/// Cost estimation structure based on token counts
///
/// Built via [`CostEstimation::new`], which derives `total_estimated_cost`
/// from the token counts and per-token rates.
#[ derive( Debug, Clone, Serialize, Deserialize ) ]
pub struct CostEstimation
{
/// Number of input tokens
pub input_tokens : u32,
/// Estimated number of output tokens
pub estimated_output_tokens : u32,
/// Cost per input token
pub input_cost_per_token : f64,
/// Cost per output token
pub output_cost_per_token : f64,
/// Total estimated cost for the operation
/// (input_tokens * input rate + estimated output tokens * output rate)
pub total_estimated_cost : f64,
/// Currency for cost calculation (defaults to "USD" via `new`)
pub currency : String,
/// Model name for cost calculation
pub model : String,
}
/// Request structure for batch token counting
///
/// Counts tokens for several texts in a single call, all against the same
/// model and tokenization options.
#[ derive( Debug, Clone, Serialize, Deserialize ) ]
pub struct BatchTokenRequest
{
/// Model name to use for all token counting
pub model : String,
/// List of texts to count tokens for
pub texts : Vec< String >,
/// Additional tokenization options, applied to every text in the batch
pub options : Option< serde_json::Value >,
/// Whether to include cost estimation in results
pub estimate_costs : bool,
}
/// Response structure for batch token counting
///
/// Aggregates the per-text results plus batch-level totals and timing.
#[ derive( Debug, Clone, Serialize, Deserialize ) ]
pub struct BatchTokenResponse
{
/// Individual token count results for each text, in input order
pub results : Vec< TokenCountResponse >,
/// Total tokens across all texts
pub total_tokens : u32,
/// Total estimated cost for all texts
pub total_estimated_cost : Option< f64 >,
/// Total processing time in milliseconds
pub processing_time_ms : Option< u64 >,
/// Optimization savings from batch processing (percentage)
/// NOTE(review): percentage relative to issuing one request per text,
/// presumably — confirm against the producer of this value.
pub batch_optimization_savings : Option< f64 >,
}
/// Configuration for token validation and limits
///
/// Drives the `exceeds_warning_threshold` / `exceeds_limit` checks and
/// describes how over-long input should be truncated.
#[ derive( Debug, Clone, Serialize, Deserialize ) ]
pub struct TokenValidationConfig
{
/// Maximum allowed input tokens
pub max_input_tokens : u32,
/// Maximum allowed output tokens
pub max_output_tokens : u32,
/// Model's context window size
pub model_context_window : u32,
/// Warning threshold as percentage of limit (0.0 to 1.0);
/// clamped into range by `with_warning_threshold`
pub warning_threshold : f64,
/// Whether to enforce limits strictly
pub enforce_limits : bool,
/// Strategy for text truncation : "start", "end", "middle"
/// NOTE(review): stringly-typed — unrecognized values are not rejected here;
/// an enum would be safer but would change the public interface.
pub truncation_strategy : String,
}
/// Model-specific token counting capabilities and costs
///
/// Construct via the `chat_model` / `code_model` presets, which split the
/// context window between input and output budgets.
#[ derive( Debug, Clone, Serialize, Deserialize ) ]
pub struct ModelTokenCapabilities
{
/// Model name
pub model_name : String,
/// Context window size in tokens
pub context_window : u32,
/// Whether model supports function calling
pub supports_function_calling : bool,
/// Average tokens per word for this model; used by `estimate_tokens`
pub average_tokens_per_word : f64,
/// Maximum input tokens for this model (a fixed fraction of the window)
pub max_input_tokens : u32,
/// Maximum output tokens for this model (the remaining fraction)
pub max_output_tokens : u32,
/// Cost per input token
pub cost_per_input_token : f64,
/// Cost per output token
pub cost_per_output_token : f64,
/// Tokenizer type used by this model (e.g. "tiktoken", "sentencepiece")
pub tokenizer_type : String,
}
impl TokenCountRequest
{
  /// Create a new token count request for `model` over `text`.
  ///
  /// `options` starts as `None`; attach tokenizer options with
  /// [`TokenCountRequest::with_options`].
  #[ inline ]
  #[ must_use ]
  pub fn new( model : String, text : String ) -> Self
  {
    Self
    {
      model,
      text,
      options : None,
    }
  }

  /// Attach additional tokenization options (builder style, consumes `self`).
  #[ inline ]
  #[ must_use ]
  pub fn with_options( mut self, options : serde_json::Value ) -> Self
  {
    self.options = Some( options );
    self
  }

  /// Get the estimated token count using simple heuristics.
  ///
  /// Rough client-side estimate only : about 1 token per 4 bytes of UTF-8
  /// (a common rule of thumb for English text), never less than 1 token.
  /// `text.len()` is the byte length, not the character count — close enough
  /// for an estimate and cheaper than iterating `chars()`.
  #[ inline ]
  #[ must_use ]
  pub fn estimate_tokens( &self ) -> u32
  {
    // Saturating conversion : the previous `as u32` cast silently wrapped on
    // 64-bit platforms for texts over ~16 GiB; clamp to u32::MAX instead.
    u32::try_from( ( self.text.len() / 4 ).max( 1 ) ).unwrap_or( u32::MAX )
  }
}
impl CostEstimation
{
  /// Build a cost estimation from token counts and per-token rates.
  ///
  /// The total is computed eagerly as input cost plus estimated output cost,
  /// and the currency defaults to `"USD"` (override with `with_currency`).
  #[ inline ]
  #[ must_use ]
  pub fn new(
    input_tokens : u32,
    estimated_output_tokens : u32,
    input_cost_per_token : f64,
    output_cost_per_token : f64,
    model : String,
  ) -> Self
  {
    let input_cost = input_tokens as f64 * input_cost_per_token;
    let output_cost = estimated_output_tokens as f64 * output_cost_per_token;
    Self
    {
      input_tokens,
      estimated_output_tokens,
      input_cost_per_token,
      output_cost_per_token,
      total_estimated_cost : input_cost + output_cost,
      currency : "USD".to_string(),
      model,
    }
  }

  /// Override the currency label for this estimate (builder style).
  #[ inline ]
  #[ must_use ]
  pub fn with_currency( mut self, currency : String ) -> Self
  {
    self.currency = currency;
    self
  }

  /// Percentage saved by this estimation relative to `other`.
  ///
  /// Positive when `self` is cheaper, negative when it is more expensive.
  /// Returns `0.0` when `other` has a zero total, avoiding division by zero.
  #[ inline ]
  #[ must_use ]
  pub fn calculate_savings( &self, other : &CostEstimation ) -> f64
  {
    let baseline = other.total_estimated_cost;
    if baseline == 0.0
    {
      return 0.0;
    }
    let delta = baseline - self.total_estimated_cost;
    delta / baseline * 100.0
  }
}
impl TokenValidationConfig
{
  /// Build a validation config with defaults : 80% warning threshold,
  /// strict enforcement, and `"end"` truncation.
  #[ inline ]
  #[ must_use ]
  pub fn new( max_input_tokens : u32, max_output_tokens : u32, model_context_window : u32 ) -> Self
  {
    Self
    {
      max_input_tokens,
      max_output_tokens,
      model_context_window,
      warning_threshold : 0.8,
      enforce_limits : true,
      truncation_strategy : "end".to_string(),
    }
  }

  /// Set the warning threshold, clamped into `[0.0, 1.0]` (builder style).
  #[ inline ]
  #[ must_use ]
  pub fn with_warning_threshold( mut self, threshold : f64 ) -> Self
  {
    self.warning_threshold = threshold.clamp( 0.0, 1.0 );
    self
  }

  /// Enable or disable strict limit enforcement (builder style).
  #[ inline ]
  #[ must_use ]
  pub fn with_enforcement( mut self, enforce : bool ) -> Self
  {
    self.enforce_limits = enforce;
    self
  }

  /// Choose the truncation strategy ("start", "end", or "middle"; builder style).
  #[ inline ]
  #[ must_use ]
  pub fn with_truncation_strategy( mut self, strategy : String ) -> Self
  {
    self.truncation_strategy = strategy;
    self
  }

  /// True when `token_count` is past the warning fraction of the input limit.
  #[ inline ]
  #[ must_use ]
  pub fn exceeds_warning_threshold( &self, token_count : u32 ) -> bool
  {
    // f64::from is lossless for u32, so this matches integer semantics exactly.
    let warning_limit = f64::from( self.max_input_tokens ) * self.warning_threshold;
    f64::from( token_count ) > warning_limit
  }

  /// True when `token_count` is strictly above the hard input limit.
  #[ inline ]
  #[ must_use ]
  pub fn exceeds_limit( &self, token_count : u32 ) -> bool
  {
    self.max_input_tokens < token_count
  }
}
impl ModelTokenCapabilities
{
  /// Capabilities preset for a general-purpose chat model.
  ///
  /// Splits the context window 75% input / 25% output, assumes ~1.3 tokens
  /// per word, enables function calling, and tags the `tiktoken` tokenizer.
  #[ inline ]
  #[ must_use ]
  pub fn chat_model( model_name : String, context_window : u32 ) -> Self
  {
    let window = context_window as f64;
    Self
    {
      model_name,
      context_window,
      supports_function_calling : true,
      average_tokens_per_word : 1.3,
      // 75% of the window reserved for input, the remaining 25% for output.
      max_input_tokens : ( window * 0.75 ) as u32,
      max_output_tokens : ( window * 0.25 ) as u32,
      cost_per_input_token : 0.0001,
      cost_per_output_token : 0.0002,
      tokenizer_type : "tiktoken".to_string(),
    }
  }

  /// Capabilities preset for a code-generation model.
  ///
  /// Code is denser per word (~1.5 tokens/word), gets an 80/20 input/output
  /// split, no function calling, and the `sentencepiece` tokenizer tag.
  #[ inline ]
  #[ must_use ]
  pub fn code_model( model_name : String, context_window : u32 ) -> Self
  {
    let window = context_window as f64;
    Self
    {
      model_name,
      context_window,
      supports_function_calling : false,
      average_tokens_per_word : 1.5,
      // 80% of the window reserved for input, the remaining 20% for output.
      max_input_tokens : ( window * 0.8 ) as u32,
      max_output_tokens : ( window * 0.2 ) as u32,
      cost_per_input_token : 0.00015,
      cost_per_output_token : 0.0003,
      tokenizer_type : "sentencepiece".to_string(),
    }
  }

  /// Estimate tokens for `text` from its whitespace-separated word count,
  /// scaled by this model's average tokens-per-word and rounded up.
  #[ inline ]
  #[ must_use ]
  pub fn estimate_tokens( &self, text : &str ) -> u32
  {
    let words = text.split_whitespace().count() as f64;
    ( words * self.average_tokens_per_word ).ceil() as u32
  }

  /// Total cost for the given input and output token counts at this model's
  /// per-token rates.
  #[ inline ]
  #[ must_use ]
  pub fn calculate_cost( &self, input_tokens : u32, output_tokens : u32 ) -> f64
  {
    // f64::from is exact for u32, identical to the `as f64` widening cast.
    let input_cost = f64::from( input_tokens ) * self.cost_per_input_token;
    let output_cost = f64::from( output_tokens ) * self.cost_per_output_token;
    input_cost + output_cost
  }
}