Skip to main content

lash_core/
model.rs

1use std::num::NonZeroUsize;
2
3#[derive(Clone, Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
4#[serde(deny_unknown_fields)]
5pub struct ModelSpec {
6    pub id: String,
7    #[serde(default, skip_serializing_if = "Option::is_none")]
8    pub variant: Option<String>,
9    pub limits: ModelLimits,
10}
11
12impl ModelSpec {
13    pub fn new(id: impl Into<String>, context_window_tokens: NonZeroUsize) -> Self {
14        Self {
15            id: id.into(),
16            variant: None,
17            limits: ModelLimits {
18                context_window_tokens,
19                input_token_capacity: None,
20                output_token_capacity: None,
21            },
22        }
23    }
24
25    pub fn with_limits(
26        id: impl Into<String>,
27        variant: Option<String>,
28        limits: ModelLimits,
29    ) -> Self {
30        Self {
31            id: id.into(),
32            variant,
33            limits,
34        }
35    }
36
37    pub fn with_variant(mut self, variant: Option<String>) -> Self {
38        self.variant = variant;
39        self
40    }
41
42    pub fn from_token_limits(
43        id: impl Into<String>,
44        variant: Option<String>,
45        context_window_tokens: usize,
46        input_token_capacity: Option<usize>,
47        output_token_capacity: Option<usize>,
48    ) -> Result<Self, String> {
49        Ok(Self::with_limits(
50            id,
51            variant,
52            ModelLimits::from_token_limits(
53                context_window_tokens,
54                input_token_capacity,
55                output_token_capacity,
56            )?,
57        ))
58    }
59
60    pub fn context_window_tokens(&self) -> usize {
61        self.limits.context_window_tokens.get()
62    }
63}
64
65impl Default for ModelSpec {
66    fn default() -> Self {
67        Self::new(
68            String::new(),
69            NonZeroUsize::new(1).expect("one is non-zero"),
70        )
71    }
72}
73
74#[derive(Clone, Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
75#[serde(deny_unknown_fields)]
76pub struct ModelLimits {
77    pub context_window_tokens: NonZeroUsize,
78    #[serde(default, skip_serializing_if = "Option::is_none")]
79    pub input_token_capacity: Option<NonZeroUsize>,
80    #[serde(default, skip_serializing_if = "Option::is_none")]
81    pub output_token_capacity: Option<NonZeroUsize>,
82}
83
84impl ModelLimits {
85    pub fn from_token_limits(
86        context_window_tokens: usize,
87        input_token_capacity: Option<usize>,
88        output_token_capacity: Option<usize>,
89    ) -> Result<Self, String> {
90        Ok(Self {
91            context_window_tokens: nonzero_token_limit(
92                "context_window_tokens",
93                context_window_tokens,
94            )?,
95            input_token_capacity: optional_nonzero_token_limit(
96                "input_token_capacity",
97                input_token_capacity,
98            )?,
99            output_token_capacity: optional_nonzero_token_limit(
100                "output_token_capacity",
101                output_token_capacity,
102            )?,
103        })
104    }
105}
106
107impl Default for ModelLimits {
108    fn default() -> Self {
109        Self {
110            context_window_tokens: NonZeroUsize::new(1).expect("one is non-zero"),
111            input_token_capacity: None,
112            output_token_capacity: None,
113        }
114    }
115}
116
/// Converts `value` into a [`NonZeroUsize`], naming the field on failure.
///
/// # Errors
///
/// Returns `"{name} must be greater than zero"` when `value` is `0`.
fn nonzero_token_limit(name: &str, value: usize) -> Result<NonZeroUsize, String> {
    match NonZeroUsize::new(value) {
        Some(limit) => Ok(limit),
        None => Err(format!("{name} must be greater than zero")),
    }
}
120
121fn optional_nonzero_token_limit(
122    name: &str,
123    value: Option<usize>,
124) -> Result<Option<NonZeroUsize>, String> {
125    value
126        .map(|value| nonzero_token_limit(name, value))
127        .transpose()
128}