Skip to main content

llama_cpp_bindings/model/
split_mode.rs

1use crate::model::llama_split_mode_parse_error::LlamaSplitModeParseError;
2
3#[repr(i8)]
4#[derive(Copy, Clone, Debug, Default, PartialEq, Eq)]
5pub enum LlamaSplitMode {
6    None = LLAMA_SPLIT_MODE_NONE,
7    #[default]
8    Layer = LLAMA_SPLIT_MODE_LAYER,
9    Row = LLAMA_SPLIT_MODE_ROW,
10    Tensor = LLAMA_SPLIT_MODE_TENSOR,
11}
12
13#[expect(
14    clippy::cast_possible_truncation,
15    reason = "the C API split mode constants are known small values that fit in i8"
16)]
17const LLAMA_SPLIT_MODE_NONE: i8 = llama_cpp_bindings_sys::LLAMA_SPLIT_MODE_NONE as i8;
18#[expect(
19    clippy::cast_possible_truncation,
20    reason = "the C API split mode constants are known small values that fit in i8"
21)]
22const LLAMA_SPLIT_MODE_LAYER: i8 = llama_cpp_bindings_sys::LLAMA_SPLIT_MODE_LAYER as i8;
23#[expect(
24    clippy::cast_possible_truncation,
25    reason = "the C API split mode constants are known small values that fit in i8"
26)]
27const LLAMA_SPLIT_MODE_ROW: i8 = llama_cpp_bindings_sys::LLAMA_SPLIT_MODE_ROW as i8;
28#[expect(
29    clippy::cast_possible_truncation,
30    reason = "the C API split mode constants are known small values that fit in i8"
31)]
32const LLAMA_SPLIT_MODE_TENSOR: i8 = llama_cpp_bindings_sys::LLAMA_SPLIT_MODE_TENSOR as i8;
33
34/// # Errors
35/// Returns `LlamaSplitModeParseError` if the value does not correspond to a valid `LlamaSplitMode`.
36impl TryFrom<i32> for LlamaSplitMode {
37    type Error = LlamaSplitModeParseError;
38
39    fn try_from(value: i32) -> Result<Self, Self::Error> {
40        let i8_value = value
41            .try_into()
42            .map_err(|convert_error| LlamaSplitModeParseError {
43                value,
44                context: format!("i32 to i8 conversion failed: {convert_error}"),
45            })?;
46
47        match i8_value {
48            LLAMA_SPLIT_MODE_NONE => Ok(Self::None),
49            LLAMA_SPLIT_MODE_LAYER => Ok(Self::Layer),
50            LLAMA_SPLIT_MODE_ROW => Ok(Self::Row),
51            LLAMA_SPLIT_MODE_TENSOR => Ok(Self::Tensor),
52            _ => Err(LlamaSplitModeParseError {
53                value,
54                context: format!("unknown split mode value: {value}"),
55            }),
56        }
57    }
58}
59
60/// # Errors
61/// Returns `LlamaSplitModeParseError` if the value does not correspond to a valid `LlamaSplitMode`.
62impl TryFrom<u32> for LlamaSplitMode {
63    type Error = LlamaSplitModeParseError;
64
65    fn try_from(value: u32) -> Result<Self, Self::Error> {
66        let clamped_value = i32::try_from(value).unwrap_or(i32::MAX);
67        let i8_value = value
68            .try_into()
69            .map_err(|convert_error| LlamaSplitModeParseError {
70                value: clamped_value,
71                context: format!("u32 to i8 conversion failed: {convert_error}"),
72            })?;
73
74        match i8_value {
75            LLAMA_SPLIT_MODE_NONE => Ok(Self::None),
76            LLAMA_SPLIT_MODE_LAYER => Ok(Self::Layer),
77            LLAMA_SPLIT_MODE_ROW => Ok(Self::Row),
78            LLAMA_SPLIT_MODE_TENSOR => Ok(Self::Tensor),
79            _ => Err(LlamaSplitModeParseError {
80                value: clamped_value,
81                context: format!("unknown split mode value: {value}"),
82            }),
83        }
84    }
85}
86
87impl From<LlamaSplitMode> for i32 {
88    fn from(value: LlamaSplitMode) -> Self {
89        match value {
90            LlamaSplitMode::None => LLAMA_SPLIT_MODE_NONE.into(),
91            LlamaSplitMode::Layer => LLAMA_SPLIT_MODE_LAYER.into(),
92            LlamaSplitMode::Row => LLAMA_SPLIT_MODE_ROW.into(),
93            LlamaSplitMode::Tensor => LLAMA_SPLIT_MODE_TENSOR.into(),
94        }
95    }
96}
97
98impl From<LlamaSplitMode> for u32 {
99    fn from(value: LlamaSplitMode) -> Self {
100        match value {
101            LlamaSplitMode::None => LLAMA_SPLIT_MODE_NONE as Self,
102            LlamaSplitMode::Layer => LLAMA_SPLIT_MODE_LAYER as Self,
103            LlamaSplitMode::Row => LLAMA_SPLIT_MODE_ROW as Self,
104            LlamaSplitMode::Tensor => LLAMA_SPLIT_MODE_TENSOR as Self,
105        }
106    }
107}
108
#[cfg(test)]
mod tests {
    use super::{
        LLAMA_SPLIT_MODE_LAYER, LLAMA_SPLIT_MODE_NONE, LLAMA_SPLIT_MODE_ROW,
        LLAMA_SPLIT_MODE_TENSOR, LlamaSplitMode,
    };

    #[test]
    fn try_from_i32_invalid() {
        let result = LlamaSplitMode::try_from(99_i32);

        assert!(result.is_err());
        let error = result.unwrap_err();
        assert_eq!(error.value, 99);
    }

    #[test]
    fn try_from_u32_invalid() {
        let error = LlamaSplitMode::try_from(99_u32).unwrap_err();

        // Mirrors the i32 test: the rejected value must be reported back to the caller.
        assert_eq!(error.value, 99);
        assert!(error.context.contains("unknown split mode value"));
    }

    #[test]
    fn try_from_i32_negative_returns_unknown_error() {
        // -1 fits in i8, so it passes narrowing and must be rejected by the variant match.
        let error = LlamaSplitMode::try_from(-1_i32).unwrap_err();

        assert_eq!(error.value, -1);
        assert!(error.context.contains("unknown split mode value"));
    }

    #[test]
    fn try_from_i32_just_outside_i8_range_returns_conversion_error() {
        for value in [i32::from(i8::MAX) + 1, i32::from(i8::MIN) - 1] {
            let error = LlamaSplitMode::try_from(value).unwrap_err();

            assert_eq!(error.value, value);
            assert!(error.context.contains("i32 to i8 conversion failed"));
        }
    }

    #[test]
    fn try_from_u32_just_outside_i8_range_returns_conversion_error() {
        let error = LlamaSplitMode::try_from(128_u32).unwrap_err();

        assert_eq!(error.value, 128);
        assert!(error.context.contains("u32 to i8 conversion failed"));
    }

    #[test]
    fn try_from_i32_none_roundtrip() {
        let mode = LlamaSplitMode::try_from(i32::from(LLAMA_SPLIT_MODE_NONE)).unwrap();

        assert_eq!(mode, LlamaSplitMode::None);
        assert_eq!(i32::from(mode), i32::from(LLAMA_SPLIT_MODE_NONE));
    }

    #[test]
    fn try_from_i32_layer_roundtrip() {
        let mode = LlamaSplitMode::try_from(i32::from(LLAMA_SPLIT_MODE_LAYER)).unwrap();

        assert_eq!(mode, LlamaSplitMode::Layer);
        assert_eq!(i32::from(mode), i32::from(LLAMA_SPLIT_MODE_LAYER));
    }

    #[test]
    fn try_from_i32_row_roundtrip() {
        let mode = LlamaSplitMode::try_from(i32::from(LLAMA_SPLIT_MODE_ROW)).unwrap();

        assert_eq!(mode, LlamaSplitMode::Row);
        assert_eq!(i32::from(mode), i32::from(LLAMA_SPLIT_MODE_ROW));
    }

    #[test]
    fn try_from_i32_tensor_roundtrip() {
        let mode = LlamaSplitMode::try_from(i32::from(LLAMA_SPLIT_MODE_TENSOR)).unwrap();

        assert_eq!(mode, LlamaSplitMode::Tensor);
        assert_eq!(i32::from(mode), i32::from(LLAMA_SPLIT_MODE_TENSOR));
    }

    #[test]
    fn try_from_u32_none_roundtrip() {
        let mode = LlamaSplitMode::try_from(LLAMA_SPLIT_MODE_NONE as u32).unwrap();

        assert_eq!(mode, LlamaSplitMode::None);
        assert_eq!(u32::from(mode), LLAMA_SPLIT_MODE_NONE as u32);
    }

    #[test]
    fn try_from_u32_layer_roundtrip() {
        let mode = LlamaSplitMode::try_from(LLAMA_SPLIT_MODE_LAYER as u32).unwrap();

        assert_eq!(mode, LlamaSplitMode::Layer);
        assert_eq!(u32::from(mode), LLAMA_SPLIT_MODE_LAYER as u32);
    }

    #[test]
    fn try_from_u32_row_roundtrip() {
        let mode = LlamaSplitMode::try_from(LLAMA_SPLIT_MODE_ROW as u32).unwrap();

        assert_eq!(mode, LlamaSplitMode::Row);
        assert_eq!(u32::from(mode), LLAMA_SPLIT_MODE_ROW as u32);
    }

    #[test]
    fn try_from_u32_tensor_roundtrip() {
        let mode = LlamaSplitMode::try_from(LLAMA_SPLIT_MODE_TENSOR as u32).unwrap();

        assert_eq!(mode, LlamaSplitMode::Tensor);
        assert_eq!(u32::from(mode), LLAMA_SPLIT_MODE_TENSOR as u32);
    }

    #[test]
    fn default_is_layer() {
        assert_eq!(LlamaSplitMode::default(), LlamaSplitMode::Layer);
    }

    #[test]
    fn try_from_i32_overflow_returns_error() {
        let result = LlamaSplitMode::try_from(i32::MAX);

        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .context
                .contains("i32 to i8 conversion failed")
        );
    }

    #[test]
    fn try_from_u32_overflow_returns_error() {
        let error = LlamaSplitMode::try_from(u32::MAX).unwrap_err();

        assert!(error.context.contains("u32 to i8 conversion failed"));
        // The error stores an i32, so inputs above i32::MAX are saturated, not wrapped.
        assert_eq!(error.value, i32::MAX);
    }
}