Skip to main content

llama_cpp_bindings/model/
split_mode.rs

1/// A rusty wrapper around `llama_split_mode`.
2#[repr(i8)]
3#[derive(Copy, Clone, Debug, PartialEq, Eq)]
4pub enum LlamaSplitMode {
5    /// Single GPU
6    None = LLAMA_SPLIT_MODE_NONE,
7    /// Split layers and KV across GPUs
8    Layer = LLAMA_SPLIT_MODE_LAYER,
9    /// Split layers and KV across GPUs, use tensor parallelism if supported
10    Row = LLAMA_SPLIT_MODE_ROW,
11}
12
13// These constants are known small u32 values from the C API that safely fit in i8.
14#[allow(clippy::cast_possible_truncation)]
15const LLAMA_SPLIT_MODE_NONE: i8 = llama_cpp_bindings_sys::LLAMA_SPLIT_MODE_NONE as i8;
16#[allow(clippy::cast_possible_truncation)]
17const LLAMA_SPLIT_MODE_LAYER: i8 = llama_cpp_bindings_sys::LLAMA_SPLIT_MODE_LAYER as i8;
18#[allow(clippy::cast_possible_truncation)]
19const LLAMA_SPLIT_MODE_ROW: i8 = llama_cpp_bindings_sys::LLAMA_SPLIT_MODE_ROW as i8;
20
/// Error produced when an integer cannot be interpreted as a split mode.
///
/// Implements [`std::fmt::Display`] and [`std::error::Error`], so it can be propagated with `?`
/// into `Box<dyn std::error::Error>` and printed with `{}`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct LlamaSplitModeParseError {
    /// The value that could not be parsed as a split mode.
    pub value: i32,
    /// Additional context about why the parse failed.
    pub context: String,
}

impl std::fmt::Display for LlamaSplitModeParseError {
    /// Renders the rejected value followed by the failure context.
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(
            formatter,
            "invalid llama split mode {}: {}",
            self.value, self.context
        )
    }
}

impl std::error::Error for LlamaSplitModeParseError {}
29
30/// Create a `LlamaSplitMode` from a `i32`.
31///
32/// # Errors
33/// Returns `LlamaSplitModeParseError` if the value does not correspond to a valid `LlamaSplitMode`.
34impl TryFrom<i32> for LlamaSplitMode {
35    type Error = LlamaSplitModeParseError;
36
37    fn try_from(value: i32) -> Result<Self, Self::Error> {
38        let i8_value = value
39            .try_into()
40            .map_err(|convert_error| LlamaSplitModeParseError {
41                value,
42                context: format!("i32 to i8 conversion failed: {convert_error}"),
43            })?;
44
45        match i8_value {
46            LLAMA_SPLIT_MODE_NONE => Ok(Self::None),
47            LLAMA_SPLIT_MODE_LAYER => Ok(Self::Layer),
48            LLAMA_SPLIT_MODE_ROW => Ok(Self::Row),
49            _ => Err(LlamaSplitModeParseError {
50                value,
51                context: format!("unknown split mode value: {value}"),
52            }),
53        }
54    }
55}
56
57/// Create a `LlamaSplitMode` from a `u32`.
58///
59/// # Errors
60/// Returns `LlamaSplitModeParseError` if the value does not correspond to a valid `LlamaSplitMode`.
61impl TryFrom<u32> for LlamaSplitMode {
62    type Error = LlamaSplitModeParseError;
63
64    fn try_from(value: u32) -> Result<Self, Self::Error> {
65        let clamped_value = i32::try_from(value).unwrap_or(i32::MAX);
66        let i8_value = value
67            .try_into()
68            .map_err(|convert_error| LlamaSplitModeParseError {
69                value: clamped_value,
70                context: format!("u32 to i8 conversion failed: {convert_error}"),
71            })?;
72
73        match i8_value {
74            LLAMA_SPLIT_MODE_NONE => Ok(Self::None),
75            LLAMA_SPLIT_MODE_LAYER => Ok(Self::Layer),
76            LLAMA_SPLIT_MODE_ROW => Ok(Self::Row),
77            _ => Err(LlamaSplitModeParseError {
78                value: clamped_value,
79                context: format!("unknown split mode value: {value}"),
80            }),
81        }
82    }
83}
84
85/// Create a `i32` from a `LlamaSplitMode`.
86impl From<LlamaSplitMode> for i32 {
87    fn from(value: LlamaSplitMode) -> Self {
88        match value {
89            LlamaSplitMode::None => LLAMA_SPLIT_MODE_NONE.into(),
90            LlamaSplitMode::Layer => LLAMA_SPLIT_MODE_LAYER.into(),
91            LlamaSplitMode::Row => LLAMA_SPLIT_MODE_ROW.into(),
92        }
93    }
94}
95
96/// Create a `u32` from a `LlamaSplitMode`.
97impl From<LlamaSplitMode> for u32 {
98    fn from(value: LlamaSplitMode) -> Self {
99        match value {
100            LlamaSplitMode::None => LLAMA_SPLIT_MODE_NONE as Self,
101            LlamaSplitMode::Layer => LLAMA_SPLIT_MODE_LAYER as Self,
102            LlamaSplitMode::Row => LLAMA_SPLIT_MODE_ROW as Self,
103        }
104    }
105}
106
107/// The default split mode is `Layer` in llama.cpp.
108impl Default for LlamaSplitMode {
109    fn default() -> Self {
110        Self::Layer
111    }
112}
113
#[cfg(test)]
mod tests {
    use super::{
        LLAMA_SPLIT_MODE_LAYER, LLAMA_SPLIT_MODE_NONE, LLAMA_SPLIT_MODE_ROW, LlamaSplitMode,
    };

    // A value that fits in `i8` but is not a known mode is rejected and echoed back.
    #[test]
    fn try_from_i32_invalid() {
        let result = LlamaSplitMode::try_from(99_i32);

        assert!(result.is_err());
        let error = result.unwrap_err();
        assert_eq!(error.value, 99);
        assert!(error.context.contains("unknown split mode value: 99"));
    }

    // Negative values fit in `i8`, so they must fail in the lookup, not in the narrowing.
    #[test]
    fn try_from_i32_negative_returns_error() {
        let error = LlamaSplitMode::try_from(-1_i32).unwrap_err();

        assert_eq!(error.value, -1);
        assert!(error.context.contains("unknown split mode value: -1"));
    }

    #[test]
    fn try_from_u32_invalid() {
        let result = LlamaSplitMode::try_from(99_u32);

        assert!(result.is_err());
        let error = result.unwrap_err();
        assert_eq!(error.value, 99);
        assert!(error.context.contains("unknown split mode value: 99"));
    }

    #[test]
    fn try_from_i32_none_roundtrip() {
        let mode = LlamaSplitMode::try_from(i32::from(LLAMA_SPLIT_MODE_NONE)).unwrap();

        assert_eq!(mode, LlamaSplitMode::None);
        assert_eq!(i32::from(mode), i32::from(LLAMA_SPLIT_MODE_NONE));
    }

    #[test]
    fn try_from_i32_layer_roundtrip() {
        let mode = LlamaSplitMode::try_from(i32::from(LLAMA_SPLIT_MODE_LAYER)).unwrap();

        assert_eq!(mode, LlamaSplitMode::Layer);
        assert_eq!(i32::from(mode), i32::from(LLAMA_SPLIT_MODE_LAYER));
    }

    #[test]
    fn try_from_i32_row_roundtrip() {
        let mode = LlamaSplitMode::try_from(i32::from(LLAMA_SPLIT_MODE_ROW)).unwrap();

        assert_eq!(mode, LlamaSplitMode::Row);
        assert_eq!(i32::from(mode), i32::from(LLAMA_SPLIT_MODE_ROW));
    }

    #[test]
    fn try_from_u32_none_roundtrip() {
        let mode = LlamaSplitMode::try_from(LLAMA_SPLIT_MODE_NONE as u32).unwrap();

        assert_eq!(mode, LlamaSplitMode::None);
        assert_eq!(u32::from(mode), LLAMA_SPLIT_MODE_NONE as u32);
    }

    #[test]
    fn try_from_u32_layer_roundtrip() {
        let mode = LlamaSplitMode::try_from(LLAMA_SPLIT_MODE_LAYER as u32).unwrap();

        assert_eq!(mode, LlamaSplitMode::Layer);
        assert_eq!(u32::from(mode), LLAMA_SPLIT_MODE_LAYER as u32);
    }

    #[test]
    fn try_from_u32_row_roundtrip() {
        let mode = LlamaSplitMode::try_from(LLAMA_SPLIT_MODE_ROW as u32).unwrap();

        assert_eq!(mode, LlamaSplitMode::Row);
        assert_eq!(u32::from(mode), LLAMA_SPLIT_MODE_ROW as u32);
    }

    #[test]
    fn default_is_layer() {
        assert_eq!(LlamaSplitMode::default(), LlamaSplitMode::Layer);
    }

    #[test]
    fn try_from_i32_overflow_returns_error() {
        let result = LlamaSplitMode::try_from(i32::MAX);

        assert!(result.is_err());
        let error = result.unwrap_err();
        assert_eq!(error.value, i32::MAX);
        assert!(error.context.contains("i32 to i8 conversion failed"));
    }

    // The lower bound of the narrowing is exercised as well as the upper one.
    #[test]
    fn try_from_i32_underflow_returns_error() {
        let error = LlamaSplitMode::try_from(i32::MIN).unwrap_err();

        assert_eq!(error.value, i32::MIN);
        assert!(error.context.contains("i32 to i8 conversion failed"));
    }

    // Values that overflow `i8` but still fit in `i32` are reported unchanged.
    #[test]
    fn try_from_u32_above_i8_range_keeps_exact_value() {
        let error = LlamaSplitMode::try_from(1000_u32).unwrap_err();

        assert_eq!(error.value, 1000);
        assert!(error.context.contains("u32 to i8 conversion failed"));
    }

    // Values above `i32::MAX` cannot be stored in the error and are reported as `i32::MAX`.
    #[test]
    fn try_from_u32_overflow_returns_error() {
        let result = LlamaSplitMode::try_from(u32::MAX);

        assert!(result.is_err());
        let error = result.unwrap_err();
        assert_eq!(error.value, i32::MAX);
        assert!(error.context.contains("u32 to i8 conversion failed"));
    }
}