burn_lm_inference/
client.rs1use std::marker::PhantomData;
2
3#[cfg(not(feature = "legacy-v018"))]
4use burn::prelude::Backend;
5
6use crate::{
7 channels::InferenceChannel,
8 errors::InferenceResult,
9 plugin::{CreateCliFlagsFn, InferencePlugin},
10 server::InferenceServer,
11 InferenceJob, Stats,
12};
13
14#[derive(Debug, Clone)]
15pub struct InferenceClient<Server: InferenceServer + 'static, Channel: 'static> {
16 model_name: &'static str,
17 model_cli_param_name: &'static str,
18 model_creation_date: &'static str,
19 owned_by: &'static str,
20 create_cli_flags_fn: CreateCliFlagsFn,
21 channel: Channel,
22 _phantom_server: PhantomData<Server>,
23}
24
25impl<Server, Channel> InferenceClient<Server, Channel>
26where
27 Server: InferenceServer,
28 Channel: InferenceChannel<Server>,
29{
30 #[allow(clippy::too_many_arguments)]
31 pub fn new(
32 model_name: &'static str,
33 model_cli_param_name: &'static str,
34 model_creation_date: &'static str,
35 owned_by: &'static str,
36 create_cli_flags_fn: CreateCliFlagsFn,
37 channel: Channel,
38 ) -> Self {
39 Self {
40 model_name,
41 model_cli_param_name,
42 model_creation_date,
43 owned_by,
44 create_cli_flags_fn,
45 channel,
46 _phantom_server: PhantomData,
47 }
48 }
49}
50
51impl<Server, Channel> InferencePlugin for InferenceClient<Server, Channel>
52where
53 Server: InferenceServer + 'static,
54 Channel: InferenceChannel<Server> + 'static,
55{
56 fn clone_box(&self) -> Box<dyn InferencePlugin> {
57 Box::new(self.clone())
58 }
59
60 fn downloader(&self) -> Option<fn() -> InferenceResult<Option<Stats>>> {
61 self.channel.downloader()
62 }
63
64 fn is_downloaded(&self) -> bool {
65 self.channel.is_downloaded()
66 }
67
68 fn deleter(&self) -> Option<fn() -> InferenceResult<Option<Stats>>> {
69 let result = self.channel.deleter();
70
71 #[cfg(not(feature = "legacy-v018"))]
72 let device = &crate::INFERENCE_DEVICE;
73
74 #[cfg(not(feature = "legacy-v018"))]
75 <crate::InferenceBackend as Backend>::memory_cleanup(device);
76
77 result
78 }
79
80 fn parse_cli_config(&self, args: &clap::ArgMatches) {
81 self.channel.parse_cli_config(args);
82 }
83
84 fn parse_json_config(&self, json: &str) {
85 self.channel.parse_json_config(json);
86 }
87
88 fn load(&self) -> InferenceResult<Option<Stats>> {
89 cfg_if::cfg_if! {
90 if #[cfg(not(feature = "legacy-v018"))] {
91 let device = &crate::INFERENCE_DEVICE;
92 <crate::InferenceBackend as Backend>::memory_static_allocations(device, (), |_| {
93 self.channel.load()
94 })
95 } else {
96 self.channel.load()
97 }
98 }
99 }
100
101 fn is_loaded(&self) -> bool {
102 self.channel.is_loaded()
103 }
104
105 fn unload(&self) -> InferenceResult<Option<Stats>> {
106 let result = self.channel.unload();
107
108 #[cfg(not(feature = "legacy-v018"))]
109 let device = &crate::INFERENCE_DEVICE;
110
111 #[cfg(not(feature = "legacy-v018"))]
112 <crate::InferenceBackend as Backend>::memory_cleanup(device);
113
114 result
115 }
116
117 fn run_job(&self, job: InferenceJob) -> InferenceResult<Stats> {
118 self.channel.run_job(job)
119 }
120
121 fn model_name(&self) -> &'static str {
122 self.model_name
123 }
124
125 fn model_cli_param_name(&self) -> &'static str {
126 self.model_cli_param_name
127 }
128
129 fn model_creation_date(&self) -> &'static str {
130 self.model_creation_date
131 }
132
133 fn owned_by(&self) -> &'static str {
134 self.owned_by
135 }
136
137 fn create_cli_flags_fn(&self) -> CreateCliFlagsFn {
138 self.create_cli_flags_fn
139 }
140
141 fn clear_state(&self) -> InferenceResult<()> {
142 self.channel.clear_state()
143 }
144}