Skip to main content

llama_cpp_bindings/model/
split_mode.rs

1use crate::model::llama_split_mode_parse_error::LlamaSplitModeParseError;
2
3/// A rusty wrapper around `llama_split_mode`.
4#[repr(i8)]
5#[derive(Copy, Clone, Debug, PartialEq, Eq)]
6pub enum LlamaSplitMode {
7    /// Single GPU
8    None = LLAMA_SPLIT_MODE_NONE,
9    /// Split layers and KV across GPUs
10    Layer = LLAMA_SPLIT_MODE_LAYER,
11    /// Split layers and KV across GPUs, use tensor parallelism if supported
12    Row = LLAMA_SPLIT_MODE_ROW,
13    /// Experimental tensor parallelism across GPUs
14    Tensor = LLAMA_SPLIT_MODE_TENSOR,
15}
16
17#[expect(
18    clippy::cast_possible_truncation,
19    reason = "the C API split mode constants are known small values that fit in i8"
20)]
21const LLAMA_SPLIT_MODE_NONE: i8 = llama_cpp_bindings_sys::LLAMA_SPLIT_MODE_NONE as i8;
22#[expect(
23    clippy::cast_possible_truncation,
24    reason = "the C API split mode constants are known small values that fit in i8"
25)]
26const LLAMA_SPLIT_MODE_LAYER: i8 = llama_cpp_bindings_sys::LLAMA_SPLIT_MODE_LAYER as i8;
27#[expect(
28    clippy::cast_possible_truncation,
29    reason = "the C API split mode constants are known small values that fit in i8"
30)]
31const LLAMA_SPLIT_MODE_ROW: i8 = llama_cpp_bindings_sys::LLAMA_SPLIT_MODE_ROW as i8;
32#[expect(
33    clippy::cast_possible_truncation,
34    reason = "the C API split mode constants are known small values that fit in i8"
35)]
36const LLAMA_SPLIT_MODE_TENSOR: i8 = llama_cpp_bindings_sys::LLAMA_SPLIT_MODE_TENSOR as i8;
37
38/// Create a `LlamaSplitMode` from a `i32`.
39///
40/// # Errors
41/// Returns `LlamaSplitModeParseError` if the value does not correspond to a valid `LlamaSplitMode`.
42impl TryFrom<i32> for LlamaSplitMode {
43    type Error = LlamaSplitModeParseError;
44
45    fn try_from(value: i32) -> Result<Self, Self::Error> {
46        let i8_value = value
47            .try_into()
48            .map_err(|convert_error| LlamaSplitModeParseError {
49                value,
50                context: format!("i32 to i8 conversion failed: {convert_error}"),
51            })?;
52
53        match i8_value {
54            LLAMA_SPLIT_MODE_NONE => Ok(Self::None),
55            LLAMA_SPLIT_MODE_LAYER => Ok(Self::Layer),
56            LLAMA_SPLIT_MODE_ROW => Ok(Self::Row),
57            LLAMA_SPLIT_MODE_TENSOR => Ok(Self::Tensor),
58            _ => Err(LlamaSplitModeParseError {
59                value,
60                context: format!("unknown split mode value: {value}"),
61            }),
62        }
63    }
64}
65
66/// Create a `LlamaSplitMode` from a `u32`.
67///
68/// # Errors
69/// Returns `LlamaSplitModeParseError` if the value does not correspond to a valid `LlamaSplitMode`.
70impl TryFrom<u32> for LlamaSplitMode {
71    type Error = LlamaSplitModeParseError;
72
73    fn try_from(value: u32) -> Result<Self, Self::Error> {
74        let clamped_value = i32::try_from(value).unwrap_or(i32::MAX);
75        let i8_value = value
76            .try_into()
77            .map_err(|convert_error| LlamaSplitModeParseError {
78                value: clamped_value,
79                context: format!("u32 to i8 conversion failed: {convert_error}"),
80            })?;
81
82        match i8_value {
83            LLAMA_SPLIT_MODE_NONE => Ok(Self::None),
84            LLAMA_SPLIT_MODE_LAYER => Ok(Self::Layer),
85            LLAMA_SPLIT_MODE_ROW => Ok(Self::Row),
86            LLAMA_SPLIT_MODE_TENSOR => Ok(Self::Tensor),
87            _ => Err(LlamaSplitModeParseError {
88                value: clamped_value,
89                context: format!("unknown split mode value: {value}"),
90            }),
91        }
92    }
93}
94
95/// Create a `i32` from a `LlamaSplitMode`.
96impl From<LlamaSplitMode> for i32 {
97    fn from(value: LlamaSplitMode) -> Self {
98        match value {
99            LlamaSplitMode::None => LLAMA_SPLIT_MODE_NONE.into(),
100            LlamaSplitMode::Layer => LLAMA_SPLIT_MODE_LAYER.into(),
101            LlamaSplitMode::Row => LLAMA_SPLIT_MODE_ROW.into(),
102            LlamaSplitMode::Tensor => LLAMA_SPLIT_MODE_TENSOR.into(),
103        }
104    }
105}
106
107/// Create a `u32` from a `LlamaSplitMode`.
108impl From<LlamaSplitMode> for u32 {
109    fn from(value: LlamaSplitMode) -> Self {
110        match value {
111            LlamaSplitMode::None => LLAMA_SPLIT_MODE_NONE as Self,
112            LlamaSplitMode::Layer => LLAMA_SPLIT_MODE_LAYER as Self,
113            LlamaSplitMode::Row => LLAMA_SPLIT_MODE_ROW as Self,
114            LlamaSplitMode::Tensor => LLAMA_SPLIT_MODE_TENSOR as Self,
115        }
116    }
117}
118
119/// The default split mode is `Layer` in llama.cpp.
120impl Default for LlamaSplitMode {
121    fn default() -> Self {
122        Self::Layer
123    }
124}
125
/// Unit tests for the `LlamaSplitMode` conversions and its default.
#[cfg(test)]
mod tests {
    use super::{
        LLAMA_SPLIT_MODE_LAYER, LLAMA_SPLIT_MODE_NONE, LLAMA_SPLIT_MODE_ROW,
        LLAMA_SPLIT_MODE_TENSOR, LlamaSplitMode,
    };

    /// An `i32` that matches no constant is rejected and echoed back in the error.
    #[test]
    fn try_from_i32_invalid() {
        let result = LlamaSplitMode::try_from(99_i32);

        assert!(result.is_err());
        let error = result.unwrap_err();
        assert_eq!(error.value, 99);
    }

    /// A `u32` that matches no constant is rejected and echoed back in the error,
    /// mirroring the `i32` case.
    #[test]
    fn try_from_u32_invalid() {
        let result = LlamaSplitMode::try_from(99_u32);

        assert!(result.is_err());
        let error = result.unwrap_err();
        assert_eq!(error.value, 99);
    }

    #[test]
    fn try_from_i32_none_roundtrip() {
        let mode = LlamaSplitMode::try_from(i32::from(LLAMA_SPLIT_MODE_NONE)).unwrap();

        assert_eq!(mode, LlamaSplitMode::None);
        assert_eq!(i32::from(mode), i32::from(LLAMA_SPLIT_MODE_NONE));
    }

    #[test]
    fn try_from_i32_layer_roundtrip() {
        let mode = LlamaSplitMode::try_from(i32::from(LLAMA_SPLIT_MODE_LAYER)).unwrap();

        assert_eq!(mode, LlamaSplitMode::Layer);
        assert_eq!(i32::from(mode), i32::from(LLAMA_SPLIT_MODE_LAYER));
    }

    #[test]
    fn try_from_i32_row_roundtrip() {
        let mode = LlamaSplitMode::try_from(i32::from(LLAMA_SPLIT_MODE_ROW)).unwrap();

        assert_eq!(mode, LlamaSplitMode::Row);
        assert_eq!(i32::from(mode), i32::from(LLAMA_SPLIT_MODE_ROW));
    }

    #[test]
    fn try_from_i32_tensor_roundtrip() {
        let mode = LlamaSplitMode::try_from(i32::from(LLAMA_SPLIT_MODE_TENSOR)).unwrap();

        assert_eq!(mode, LlamaSplitMode::Tensor);
        assert_eq!(i32::from(mode), i32::from(LLAMA_SPLIT_MODE_TENSOR));
    }

    #[test]
    fn try_from_u32_none_roundtrip() {
        let mode = LlamaSplitMode::try_from(LLAMA_SPLIT_MODE_NONE as u32).unwrap();

        assert_eq!(mode, LlamaSplitMode::None);
        assert_eq!(u32::from(mode), LLAMA_SPLIT_MODE_NONE as u32);
    }

    #[test]
    fn try_from_u32_layer_roundtrip() {
        let mode = LlamaSplitMode::try_from(LLAMA_SPLIT_MODE_LAYER as u32).unwrap();

        assert_eq!(mode, LlamaSplitMode::Layer);
        assert_eq!(u32::from(mode), LLAMA_SPLIT_MODE_LAYER as u32);
    }

    #[test]
    fn try_from_u32_row_roundtrip() {
        let mode = LlamaSplitMode::try_from(LLAMA_SPLIT_MODE_ROW as u32).unwrap();

        assert_eq!(mode, LlamaSplitMode::Row);
        assert_eq!(u32::from(mode), LLAMA_SPLIT_MODE_ROW as u32);
    }

    #[test]
    fn try_from_u32_tensor_roundtrip() {
        let mode = LlamaSplitMode::try_from(LLAMA_SPLIT_MODE_TENSOR as u32).unwrap();

        assert_eq!(mode, LlamaSplitMode::Tensor);
        assert_eq!(u32::from(mode), LLAMA_SPLIT_MODE_TENSOR as u32);
    }

    #[test]
    fn default_is_layer() {
        assert_eq!(LlamaSplitMode::default(), LlamaSplitMode::Layer);
    }

    /// An `i32` outside the `i8` range fails at the narrowing step.
    #[test]
    fn try_from_i32_overflow_returns_error() {
        let result = LlamaSplitMode::try_from(i32::MAX);

        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .context
                .contains("i32 to i8 conversion failed")
        );
    }

    /// A `u32` outside the `i8` range fails at the narrowing step, and the reported
    /// `value` is clamped to `i32::MAX` because the error field is an `i32`.
    #[test]
    fn try_from_u32_overflow_returns_error() {
        let result = LlamaSplitMode::try_from(u32::MAX);

        assert!(result.is_err());
        let error = result.unwrap_err();
        assert_eq!(error.value, i32::MAX);
        assert!(error.context.contains("u32 to i8 conversion failed"));
    }
}