Skip to main content

llama_cpp_bindings/model/
split_mode.rs

1/// A rusty wrapper around `llama_split_mode`.
2#[repr(i8)]
3#[derive(Copy, Clone, Debug, PartialEq, Eq)]
4pub enum LlamaSplitMode {
5    /// Single GPU
6    None = LLAMA_SPLIT_MODE_NONE,
7    /// Split layers and KV across GPUs
8    Layer = LLAMA_SPLIT_MODE_LAYER,
9    /// Split layers and KV across GPUs, use tensor parallelism if supported
10    Row = LLAMA_SPLIT_MODE_ROW,
11    /// Experimental tensor parallelism across GPUs
12    Tensor = LLAMA_SPLIT_MODE_TENSOR,
13}
14
/// `i8` form of the C API `LLAMA_SPLIT_MODE_NONE` constant.
///
/// Used as the discriminant of `LlamaSplitMode::None` and as a match pattern in the
/// `TryFrom` conversions.
#[expect(
    clippy::cast_possible_truncation,
    reason = "the C API split mode constants are known small values that fit in i8"
)]
const LLAMA_SPLIT_MODE_NONE: i8 = llama_cpp_bindings_sys::LLAMA_SPLIT_MODE_NONE as i8;
/// `i8` form of the C API `LLAMA_SPLIT_MODE_LAYER` constant.
///
/// Used as the discriminant of `LlamaSplitMode::Layer` and as a match pattern in the
/// `TryFrom` conversions.
#[expect(
    clippy::cast_possible_truncation,
    reason = "the C API split mode constants are known small values that fit in i8"
)]
const LLAMA_SPLIT_MODE_LAYER: i8 = llama_cpp_bindings_sys::LLAMA_SPLIT_MODE_LAYER as i8;
/// `i8` form of the C API `LLAMA_SPLIT_MODE_ROW` constant.
///
/// Used as the discriminant of `LlamaSplitMode::Row` and as a match pattern in the
/// `TryFrom` conversions.
#[expect(
    clippy::cast_possible_truncation,
    reason = "the C API split mode constants are known small values that fit in i8"
)]
const LLAMA_SPLIT_MODE_ROW: i8 = llama_cpp_bindings_sys::LLAMA_SPLIT_MODE_ROW as i8;
/// `i8` form of the C API `LLAMA_SPLIT_MODE_TENSOR` constant.
///
/// Used as the discriminant of `LlamaSplitMode::Tensor` and as a match pattern in the
/// `TryFrom` conversions.
#[expect(
    clippy::cast_possible_truncation,
    reason = "the C API split mode constants are known small values that fit in i8"
)]
const LLAMA_SPLIT_MODE_TENSOR: i8 = llama_cpp_bindings_sys::LLAMA_SPLIT_MODE_TENSOR as i8;
35
/// An error that occurs when an unknown split mode is encountered.
///
/// Implements [`std::fmt::Display`] and [`std::error::Error`] so it can be propagated with
/// `?` into boxed or dynamic error types and printed in user-facing messages.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct LlamaSplitModeParseError {
    /// The value that could not be parsed as a split mode.
    pub value: i32,
    /// Additional context about why the parse failed.
    pub context: String,
}

impl std::fmt::Display for LlamaSplitModeParseError {
    /// Writes the rejected value followed by the failure context.
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(
            formatter,
            "invalid split mode value {}: {}",
            self.value, self.context
        )
    }
}

// No underlying source error is stored, so the default `source()` (returning `None`) applies.
impl std::error::Error for LlamaSplitModeParseError {}
44
45/// Create a `LlamaSplitMode` from a `i32`.
46///
47/// # Errors
48/// Returns `LlamaSplitModeParseError` if the value does not correspond to a valid `LlamaSplitMode`.
49impl TryFrom<i32> for LlamaSplitMode {
50    type Error = LlamaSplitModeParseError;
51
52    fn try_from(value: i32) -> Result<Self, Self::Error> {
53        let i8_value = value
54            .try_into()
55            .map_err(|convert_error| LlamaSplitModeParseError {
56                value,
57                context: format!("i32 to i8 conversion failed: {convert_error}"),
58            })?;
59
60        match i8_value {
61            LLAMA_SPLIT_MODE_NONE => Ok(Self::None),
62            LLAMA_SPLIT_MODE_LAYER => Ok(Self::Layer),
63            LLAMA_SPLIT_MODE_ROW => Ok(Self::Row),
64            LLAMA_SPLIT_MODE_TENSOR => Ok(Self::Tensor),
65            _ => Err(LlamaSplitModeParseError {
66                value,
67                context: format!("unknown split mode value: {value}"),
68            }),
69        }
70    }
71}
72
73/// Create a `LlamaSplitMode` from a `u32`.
74///
75/// # Errors
76/// Returns `LlamaSplitModeParseError` if the value does not correspond to a valid `LlamaSplitMode`.
77impl TryFrom<u32> for LlamaSplitMode {
78    type Error = LlamaSplitModeParseError;
79
80    fn try_from(value: u32) -> Result<Self, Self::Error> {
81        let clamped_value = i32::try_from(value).unwrap_or(i32::MAX);
82        let i8_value = value
83            .try_into()
84            .map_err(|convert_error| LlamaSplitModeParseError {
85                value: clamped_value,
86                context: format!("u32 to i8 conversion failed: {convert_error}"),
87            })?;
88
89        match i8_value {
90            LLAMA_SPLIT_MODE_NONE => Ok(Self::None),
91            LLAMA_SPLIT_MODE_LAYER => Ok(Self::Layer),
92            LLAMA_SPLIT_MODE_ROW => Ok(Self::Row),
93            LLAMA_SPLIT_MODE_TENSOR => Ok(Self::Tensor),
94            _ => Err(LlamaSplitModeParseError {
95                value: clamped_value,
96                context: format!("unknown split mode value: {value}"),
97            }),
98        }
99    }
100}
101
102/// Create a `i32` from a `LlamaSplitMode`.
103impl From<LlamaSplitMode> for i32 {
104    fn from(value: LlamaSplitMode) -> Self {
105        match value {
106            LlamaSplitMode::None => LLAMA_SPLIT_MODE_NONE.into(),
107            LlamaSplitMode::Layer => LLAMA_SPLIT_MODE_LAYER.into(),
108            LlamaSplitMode::Row => LLAMA_SPLIT_MODE_ROW.into(),
109            LlamaSplitMode::Tensor => LLAMA_SPLIT_MODE_TENSOR.into(),
110        }
111    }
112}
113
/// Create a `u32` from a `LlamaSplitMode`.
///
/// Each variant maps to its `LLAMA_SPLIT_MODE_*` constant, cast from `i8` to `u32`.
impl From<LlamaSplitMode> for u32 {
    fn from(value: LlamaSplitMode) -> Self {
        // `as` is used because the standard library has no lossless `From<i8>` for `u32`.
        // NOTE(review): this relies on every constant being non-negative; a negative `i8`
        // would wrap to a large `u32` under `as` — confirm against the C header.
        match value {
            LlamaSplitMode::None => LLAMA_SPLIT_MODE_NONE as Self,
            LlamaSplitMode::Layer => LLAMA_SPLIT_MODE_LAYER as Self,
            LlamaSplitMode::Row => LLAMA_SPLIT_MODE_ROW as Self,
            LlamaSplitMode::Tensor => LLAMA_SPLIT_MODE_TENSOR as Self,
        }
    }
}
125
126/// The default split mode is `Layer` in llama.cpp.
127impl Default for LlamaSplitMode {
128    fn default() -> Self {
129        Self::Layer
130    }
131}
132
#[cfg(test)]
mod tests {
    use super::{
        LLAMA_SPLIT_MODE_LAYER, LLAMA_SPLIT_MODE_NONE, LLAMA_SPLIT_MODE_ROW,
        LLAMA_SPLIT_MODE_TENSOR, LlamaSplitMode,
    };

    // --- Rejected inputs -------------------------------------------------

    #[test]
    fn try_from_i32_invalid() {
        // `unwrap_err` panics (failing the test) if the conversion unexpectedly succeeds.
        let error = LlamaSplitMode::try_from(99_i32).unwrap_err();

        assert_eq!(error.value, 99);
    }

    #[test]
    fn try_from_u32_invalid() {
        let outcome = LlamaSplitMode::try_from(99_u32);

        assert!(outcome.is_err());
    }

    // --- i32 round trips -------------------------------------------------

    #[test]
    fn try_from_i32_none_roundtrip() {
        let raw = i32::from(LLAMA_SPLIT_MODE_NONE);
        let mode = LlamaSplitMode::try_from(raw).unwrap();

        assert_eq!(mode, LlamaSplitMode::None);
        assert_eq!(i32::from(mode), raw);
    }

    #[test]
    fn try_from_i32_layer_roundtrip() {
        let raw = i32::from(LLAMA_SPLIT_MODE_LAYER);
        let mode = LlamaSplitMode::try_from(raw).unwrap();

        assert_eq!(mode, LlamaSplitMode::Layer);
        assert_eq!(i32::from(mode), raw);
    }

    #[test]
    fn try_from_i32_row_roundtrip() {
        let raw = i32::from(LLAMA_SPLIT_MODE_ROW);
        let mode = LlamaSplitMode::try_from(raw).unwrap();

        assert_eq!(mode, LlamaSplitMode::Row);
        assert_eq!(i32::from(mode), raw);
    }

    #[test]
    fn try_from_i32_tensor_roundtrip() {
        let raw = i32::from(LLAMA_SPLIT_MODE_TENSOR);
        let mode = LlamaSplitMode::try_from(raw).unwrap();

        assert_eq!(mode, LlamaSplitMode::Tensor);
        assert_eq!(i32::from(mode), raw);
    }

    // --- u32 round trips -------------------------------------------------

    #[test]
    fn try_from_u32_none_roundtrip() {
        let raw = LLAMA_SPLIT_MODE_NONE as u32;
        let mode = LlamaSplitMode::try_from(raw).unwrap();

        assert_eq!(mode, LlamaSplitMode::None);
        assert_eq!(u32::from(mode), raw);
    }

    #[test]
    fn try_from_u32_layer_roundtrip() {
        let raw = LLAMA_SPLIT_MODE_LAYER as u32;
        let mode = LlamaSplitMode::try_from(raw).unwrap();

        assert_eq!(mode, LlamaSplitMode::Layer);
        assert_eq!(u32::from(mode), raw);
    }

    #[test]
    fn try_from_u32_row_roundtrip() {
        let raw = LLAMA_SPLIT_MODE_ROW as u32;
        let mode = LlamaSplitMode::try_from(raw).unwrap();

        assert_eq!(mode, LlamaSplitMode::Row);
        assert_eq!(u32::from(mode), raw);
    }

    #[test]
    fn try_from_u32_tensor_roundtrip() {
        let raw = LLAMA_SPLIT_MODE_TENSOR as u32;
        let mode = LlamaSplitMode::try_from(raw).unwrap();

        assert_eq!(mode, LlamaSplitMode::Tensor);
        assert_eq!(u32::from(mode), raw);
    }

    // --- Default ---------------------------------------------------------

    #[test]
    fn default_is_layer() {
        let mode = LlamaSplitMode::default();

        assert_eq!(mode, LlamaSplitMode::Layer);
    }

    // --- Out-of-range inputs ---------------------------------------------

    #[test]
    fn try_from_i32_overflow_returns_error() {
        let error = LlamaSplitMode::try_from(i32::MAX).unwrap_err();

        assert!(error.context.contains("i32 to i8 conversion failed"));
    }

    #[test]
    fn try_from_u32_overflow_returns_error() {
        let error = LlamaSplitMode::try_from(u32::MAX).unwrap_err();

        assert!(error.context.contains("u32 to i8 conversion failed"));
    }
}