Skip to main content

llama_cpp_4/rpc/
backend.rs

1//! RPC backend for distributed inference
2
3use crate::rpc::error::RpcError;
4use llama_cpp_sys_4 as sys;
5use std::ffi::CString;
6use std::ptr::NonNull;
7
8/// RPC backend for distributed inference across multiple machines
9pub struct RpcBackend {
10    backend: NonNull<sys::ggml_backend>,
11    endpoint: String,
12}
13
14impl RpcBackend {
15    /// Initialize a new RPC backend for the given endpoint
16    ///
17    /// # Arguments
18    /// * `endpoint` - The RPC server endpoint (e.g., "127.0.0.1:50052")
19    ///
20    /// # Example
21    /// ```no_run
22    /// use llama_cpp_4::rpc::RpcBackend;
23    ///
24    /// let backend = RpcBackend::init("127.0.0.1:50052")?;
25    /// ```
26    pub fn init(endpoint: &str) -> Result<Self, RpcError> {
27        let c_endpoint = CString::new(endpoint).map_err(|e| RpcError::StringConversion(e))?;
28
29        let backend = unsafe { sys::ggml_backend_rpc_init(c_endpoint.as_ptr()) };
30
31        NonNull::new(backend)
32            .map(|ptr| Self {
33                backend: ptr,
34                endpoint: endpoint.to_string(),
35            })
36            .ok_or_else(|| RpcError::InitializationFailed {
37                endpoint: endpoint.to_string(),
38            })
39    }
40
41    /// Check if a backend is an RPC backend
42    pub fn is_rpc(&self) -> bool {
43        unsafe { sys::ggml_backend_is_rpc(self.backend.as_ptr()) }
44    }
45
46    /// Get the buffer type for this RPC backend
47    pub fn buffer_type(&self) -> Option<NonNull<sys::ggml_backend_buffer_type>> {
48        let c_endpoint = CString::new(self.endpoint.as_str()).ok()?;
49        let buffer_type = unsafe { sys::ggml_backend_rpc_buffer_type(c_endpoint.as_ptr()) };
50        NonNull::new(buffer_type)
51    }
52
53    /// Query the available memory on the remote device
54    ///
55    /// Returns (free_memory, total_memory) in bytes
56    pub fn get_device_memory(&self) -> Result<(usize, usize), RpcError> {
57        let c_endpoint =
58            CString::new(self.endpoint.as_str()).map_err(|e| RpcError::StringConversion(e))?;
59
60        let mut free: usize = 0;
61        let mut total: usize = 0;
62
63        unsafe {
64            sys::ggml_backend_rpc_get_device_memory(c_endpoint.as_ptr(), &mut free, &mut total);
65        }
66
67        if total == 0 {
68            Err(RpcError::MemoryQueryFailed)
69        } else {
70            Ok((free, total))
71        }
72    }
73
74    /// Get the endpoint this backend is connected to
75    pub fn endpoint(&self) -> &str {
76        &self.endpoint
77    }
78
79    /// Get the raw backend pointer for FFI calls
80    pub(crate) fn as_ptr(&self) -> NonNull<sys::ggml_backend> {
81        self.backend
82    }
83}
84
85impl Drop for RpcBackend {
86    fn drop(&mut self) {
87        unsafe {
88            sys::ggml_backend_free(self.backend.as_ptr());
89        }
90    }
91}
92
93// Safety: RpcBackend can be sent between threads
94unsafe impl Send for RpcBackend {}
95// Safety: RpcBackend can be shared between threads (the C API is thread-safe)
96unsafe impl Sync for RpcBackend {}
97
98impl std::fmt::Debug for RpcBackend {
99    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
100        f.debug_struct("RpcBackend")
101            .field("endpoint", &self.endpoint)
102            .field("is_rpc", &self.is_rpc())
103            .finish()
104    }
105}