burn_lm_inference/
client.rs

1use std::marker::PhantomData;
2
3#[cfg(not(feature = "legacy-v018"))]
4use burn::prelude::Backend;
5
6use crate::{
7    channels::InferenceChannel,
8    errors::InferenceResult,
9    plugin::{CreateCliFlagsFn, InferencePlugin},
10    server::InferenceServer,
11    InferenceJob, Stats,
12};
13
14#[derive(Debug, Clone)]
15pub struct InferenceClient<Server: InferenceServer + 'static, Channel: 'static> {
16    model_name: &'static str,
17    model_cli_param_name: &'static str,
18    model_creation_date: &'static str,
19    owned_by: &'static str,
20    create_cli_flags_fn: CreateCliFlagsFn,
21    channel: Channel,
22    _phantom_server: PhantomData<Server>,
23}
24
25impl<Server, Channel> InferenceClient<Server, Channel>
26where
27    Server: InferenceServer,
28    Channel: InferenceChannel<Server>,
29{
30    #[allow(clippy::too_many_arguments)]
31    pub fn new(
32        model_name: &'static str,
33        model_cli_param_name: &'static str,
34        model_creation_date: &'static str,
35        owned_by: &'static str,
36        create_cli_flags_fn: CreateCliFlagsFn,
37        channel: Channel,
38    ) -> Self {
39        Self {
40            model_name,
41            model_cli_param_name,
42            model_creation_date,
43            owned_by,
44            create_cli_flags_fn,
45            channel,
46            _phantom_server: PhantomData,
47        }
48    }
49}
50
51impl<Server, Channel> InferencePlugin for InferenceClient<Server, Channel>
52where
53    Server: InferenceServer + 'static,
54    Channel: InferenceChannel<Server> + 'static,
55{
56    fn clone_box(&self) -> Box<dyn InferencePlugin> {
57        Box::new(self.clone())
58    }
59
60    fn downloader(&self) -> Option<fn() -> InferenceResult<Option<Stats>>> {
61        self.channel.downloader()
62    }
63
64    fn is_downloaded(&self) -> bool {
65        self.channel.is_downloaded()
66    }
67
68    fn deleter(&self) -> Option<fn() -> InferenceResult<Option<Stats>>> {
69        let result = self.channel.deleter();
70
71        #[cfg(not(feature = "legacy-v018"))]
72        let device = &crate::INFERENCE_DEVICE;
73
74        #[cfg(not(feature = "legacy-v018"))]
75        <crate::InferenceBackend as Backend>::memory_cleanup(device);
76
77        result
78    }
79
80    fn parse_cli_config(&self, args: &clap::ArgMatches) {
81        self.channel.parse_cli_config(args);
82    }
83
84    fn parse_json_config(&self, json: &str) {
85        self.channel.parse_json_config(json);
86    }
87
88    fn load(&self) -> InferenceResult<Option<Stats>> {
89        cfg_if::cfg_if! {
90            if #[cfg(not(feature = "legacy-v018"))] {
91                let device = &crate::INFERENCE_DEVICE;
92                <crate::InferenceBackend as Backend>::memory_static_allocations(device, (), |_| {
93                    self.channel.load()
94                })
95            } else {
96                self.channel.load()
97            }
98        }
99    }
100
101    fn is_loaded(&self) -> bool {
102        self.channel.is_loaded()
103    }
104
105    fn unload(&self) -> InferenceResult<Option<Stats>> {
106        let result = self.channel.unload();
107
108        #[cfg(not(feature = "legacy-v018"))]
109        let device = &crate::INFERENCE_DEVICE;
110
111        #[cfg(not(feature = "legacy-v018"))]
112        <crate::InferenceBackend as Backend>::memory_cleanup(device);
113
114        result
115    }
116
117    fn run_job(&self, job: InferenceJob) -> InferenceResult<Stats> {
118        self.channel.run_job(job)
119    }
120
121    fn model_name(&self) -> &'static str {
122        self.model_name
123    }
124
125    fn model_cli_param_name(&self) -> &'static str {
126        self.model_cli_param_name
127    }
128
129    fn model_creation_date(&self) -> &'static str {
130        self.model_creation_date
131    }
132
133    fn owned_by(&self) -> &'static str {
134        self.owned_by
135    }
136
137    fn create_cli_flags_fn(&self) -> CreateCliFlagsFn {
138        self.create_cli_flags_fn
139    }
140
141    fn clear_state(&self) -> InferenceResult<()> {
142        self.channel.clear_state()
143    }
144}