Skip to main content

llama_cpp_bindings/model/
split_mode.rs

1/// A rusty wrapper around `llama_split_mode`.
2#[repr(i8)]
3#[derive(Copy, Clone, Debug, PartialEq, Eq)]
4pub enum LlamaSplitMode {
5    /// Single GPU
6    None = LLAMA_SPLIT_MODE_NONE,
7    /// Split layers and KV across GPUs
8    Layer = LLAMA_SPLIT_MODE_LAYER,
9    /// Split layers and KV across GPUs, use tensor parallelism if supported
10    Row = LLAMA_SPLIT_MODE_ROW,
11}
12
// `i8` copies of the raw `llama_cpp_bindings_sys` split-mode values. They are
// typed `i8` so they can serve both as the discriminants of the `#[repr(i8)]`
// `LlamaSplitMode` enum and as the patterns matched in its `TryFrom` impls.
// NOTE(review): `as i8` truncates silently; this presumes every upstream value
// fits in an `i8` — confirm against the generated bindings.

/// Raw `LLAMA_SPLIT_MODE_NONE` value, narrowed to `i8`.
const LLAMA_SPLIT_MODE_NONE: i8 = llama_cpp_bindings_sys::LLAMA_SPLIT_MODE_NONE as i8;
/// Raw `LLAMA_SPLIT_MODE_LAYER` value, narrowed to `i8`.
const LLAMA_SPLIT_MODE_LAYER: i8 = llama_cpp_bindings_sys::LLAMA_SPLIT_MODE_LAYER as i8;
/// Raw `LLAMA_SPLIT_MODE_ROW` value, narrowed to `i8`.
const LLAMA_SPLIT_MODE_ROW: i8 = llama_cpp_bindings_sys::LLAMA_SPLIT_MODE_ROW as i8;
16
/// An error that occurs when an unknown split mode is encountered.
///
/// Implements [`std::fmt::Display`] and [`std::error::Error`] so it can be
/// boxed as `dyn Error` and propagated with `?` into generic error containers.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct LlamaSplitModeParseError {
    /// The value that could not be parsed as a split mode.
    ///
    /// The `u32` conversion stores `i32::MAX` here when its input does not fit
    /// in an `i32`.
    pub value: i32,
    /// Additional context about why the parse failed.
    pub context: String,
}

/// Renders the rejected value followed by the failure context.
impl std::fmt::Display for LlamaSplitModeParseError {
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(formatter, "invalid llama split mode {}: {}", self.value, self.context)
    }
}

/// No underlying source error is kept; only `value` and `context` are stored.
impl std::error::Error for LlamaSplitModeParseError {}
25
26/// Create a `LlamaSplitMode` from a `i32`.
27///
28/// # Errors
29/// Returns `LlamaSplitModeParseError` if the value does not correspond to a valid `LlamaSplitMode`.
30impl TryFrom<i32> for LlamaSplitMode {
31    type Error = LlamaSplitModeParseError;
32
33    fn try_from(value: i32) -> Result<Self, Self::Error> {
34        let i8_value = value
35            .try_into()
36            .map_err(|convert_error| LlamaSplitModeParseError {
37                value,
38                context: format!("i32 to i8 conversion failed: {convert_error}"),
39            })?;
40
41        match i8_value {
42            LLAMA_SPLIT_MODE_NONE => Ok(Self::None),
43            LLAMA_SPLIT_MODE_LAYER => Ok(Self::Layer),
44            LLAMA_SPLIT_MODE_ROW => Ok(Self::Row),
45            _ => Err(LlamaSplitModeParseError {
46                value,
47                context: format!("unknown split mode value: {value}"),
48            }),
49        }
50    }
51}
52
53/// Create a `LlamaSplitMode` from a `u32`.
54///
55/// # Errors
56/// Returns `LlamaSplitModeParseError` if the value does not correspond to a valid `LlamaSplitMode`.
57impl TryFrom<u32> for LlamaSplitMode {
58    type Error = LlamaSplitModeParseError;
59
60    fn try_from(value: u32) -> Result<Self, Self::Error> {
61        let clamped_value = i32::try_from(value).unwrap_or(i32::MAX);
62        let i8_value = value
63            .try_into()
64            .map_err(|convert_error| LlamaSplitModeParseError {
65                value: clamped_value,
66                context: format!("u32 to i8 conversion failed: {convert_error}"),
67            })?;
68
69        match i8_value {
70            LLAMA_SPLIT_MODE_NONE => Ok(Self::None),
71            LLAMA_SPLIT_MODE_LAYER => Ok(Self::Layer),
72            LLAMA_SPLIT_MODE_ROW => Ok(Self::Row),
73            _ => Err(LlamaSplitModeParseError {
74                value: clamped_value,
75                context: format!("unknown split mode value: {value}"),
76            }),
77        }
78    }
79}
80
81/// Create a `i32` from a `LlamaSplitMode`.
82impl From<LlamaSplitMode> for i32 {
83    fn from(value: LlamaSplitMode) -> Self {
84        match value {
85            LlamaSplitMode::None => LLAMA_SPLIT_MODE_NONE.into(),
86            LlamaSplitMode::Layer => LLAMA_SPLIT_MODE_LAYER.into(),
87            LlamaSplitMode::Row => LLAMA_SPLIT_MODE_ROW.into(),
88        }
89    }
90}
91
92/// Create a `u32` from a `LlamaSplitMode`.
93impl From<LlamaSplitMode> for u32 {
94    fn from(value: LlamaSplitMode) -> Self {
95        match value {
96            LlamaSplitMode::None => LLAMA_SPLIT_MODE_NONE as u32,
97            LlamaSplitMode::Layer => LLAMA_SPLIT_MODE_LAYER as u32,
98            LlamaSplitMode::Row => LLAMA_SPLIT_MODE_ROW as u32,
99        }
100    }
101}
102
103/// The default split mode is `Layer` in llama.cpp.
104impl Default for LlamaSplitMode {
105    fn default() -> Self {
106        LlamaSplitMode::Layer
107    }
108}
109
#[cfg(test)]
mod tests {
    use super::LlamaSplitMode;

    /// An unknown `i32` is rejected and echoed back in the error's `value`.
    #[test]
    fn try_from_i32_invalid() {
        let outcome = LlamaSplitMode::try_from(99_i32);
        assert!(outcome.is_err());

        let rejected = outcome.unwrap_err().value;
        assert_eq!(rejected, 99);
    }

    /// An unknown `u32` is rejected as well.
    #[test]
    fn try_from_u32_invalid() {
        let outcome = LlamaSplitMode::try_from(99_u32);

        assert!(outcome.is_err());
    }
}