llama_cpp_4/rpc/
backend.rs1use crate::rpc::error::RpcError;
4use llama_cpp_sys_4 as sys;
5use std::ffi::CString;
6use std::ptr::NonNull;
7
8pub struct RpcBackend {
10 backend: NonNull<sys::ggml_backend>,
11 endpoint: String,
12}
13
14impl RpcBackend {
15 pub fn init(endpoint: &str) -> Result<Self, RpcError> {
27 let c_endpoint = CString::new(endpoint).map_err(|e| RpcError::StringConversion(e))?;
28
29 let backend = unsafe { sys::ggml_backend_rpc_init(c_endpoint.as_ptr()) };
30
31 NonNull::new(backend)
32 .map(|ptr| Self {
33 backend: ptr,
34 endpoint: endpoint.to_string(),
35 })
36 .ok_or_else(|| RpcError::InitializationFailed {
37 endpoint: endpoint.to_string(),
38 })
39 }
40
41 pub fn is_rpc(&self) -> bool {
43 unsafe { sys::ggml_backend_is_rpc(self.backend.as_ptr()) }
44 }
45
46 pub fn buffer_type(&self) -> Option<NonNull<sys::ggml_backend_buffer_type>> {
48 let c_endpoint = CString::new(self.endpoint.as_str()).ok()?;
49 let buffer_type = unsafe { sys::ggml_backend_rpc_buffer_type(c_endpoint.as_ptr()) };
50 NonNull::new(buffer_type)
51 }
52
53 pub fn get_device_memory(&self) -> Result<(usize, usize), RpcError> {
57 let c_endpoint =
58 CString::new(self.endpoint.as_str()).map_err(|e| RpcError::StringConversion(e))?;
59
60 let mut free: usize = 0;
61 let mut total: usize = 0;
62
63 unsafe {
64 sys::ggml_backend_rpc_get_device_memory(c_endpoint.as_ptr(), &mut free, &mut total);
65 }
66
67 if total == 0 {
68 Err(RpcError::MemoryQueryFailed)
69 } else {
70 Ok((free, total))
71 }
72 }
73
74 pub fn endpoint(&self) -> &str {
76 &self.endpoint
77 }
78
79 pub(crate) fn as_ptr(&self) -> NonNull<sys::ggml_backend> {
81 self.backend
82 }
83}
84
85impl Drop for RpcBackend {
86 fn drop(&mut self) {
87 unsafe {
88 sys::ggml_backend_free(self.backend.as_ptr());
89 }
90 }
91}
92
93unsafe impl Send for RpcBackend {}
95unsafe impl Sync for RpcBackend {}
97
98impl std::fmt::Debug for RpcBackend {
99 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
100 f.debug_struct("RpcBackend")
101 .field("endpoint", &self.endpoint)
102 .field("is_rpc", &self.is_rpc())
103 .finish()
104 }
105}