1use serde::{Deserialize, Serialize};
4use thiserror::Error;
5
6pub type Result<T> = std::result::Result<T, OcelotlError>;
7
8#[derive(Debug, Error)]
9pub enum OcelotlError {
10 #[error("invalid model artifact: {0}")]
11 InvalidModel(String),
12 #[error("unsupported feature: {0}")]
13 Unsupported(String),
14 #[error("runtime error: {0}")]
15 Runtime(String),
16}
17
18#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
19pub enum Device {
20 Cpu,
21 Gpu { ordinal: usize },
22}
23
24#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
25pub enum DType {
26 F32,
27 F16,
28 BF16,
29 Q4,
30 Q8,
31}
32
33#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
34pub struct ModelInfo {
35 pub architecture: String,
36 pub parameter_count: Option<u64>,
37 pub context_length: usize,
38 pub dtype: DType,
39}
40
41#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
42pub struct GenerationOptions {
43 pub max_new_tokens: usize,
44 pub temperature: Option<u32>,
45}
46
47impl Default for GenerationOptions {
48 fn default() -> Self {
49 Self {
50 max_new_tokens: 256,
51 temperature: None,
52 }
53 }
54}