burn_lm_parrot/lib.rs
#![recursion_limit = "256"]

use burn_lm_inference::{InferenceJob, *};

// This is where you can declare the configuration parameters for
// your model.
//
// On each field, add the `config` attribute to define a default value for it.
//
// Those parameters will be available in the `burn-lm` CLI.
//
// Parameters sent by Open WebUI are automatically mapped to the configuration
// parameters with the same name. It is also possible to map to an Open WebUI
// parameter with a different name using `#[config(openwebui_param = "param_name")]`.
// Examples of Open WebUI parameters are `temperature`, `seed` and `top_p`.
#[inference_server_config]
pub struct ParrotServerConfig {
    /// Temperature value for controlling randomness in sampling.
    #[config(default = 0.1)]
    pub temperature: f64,
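    // Hypothetical example (sketch, not part of the template): expose a field
    // under the Open WebUI name `top_p` instead of the Rust field name. A
    // `default` may also be required depending on the macro:
    //
    //     #[config(openwebui_param = "top_p")]
    //     pub nucleus: f64,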
}

// Declare the model server info using the `InferenceServer` derive
// and the `inference_server` attribute. The structure must be generic over the
// Burn backends.
//
// Register the model by adding a dependency on this crate in the
// `burn-lm-registry` crate.
//
// Then add an entry to the `inference_server_registry` attribute in `lib.rs`
// of the `burn-lm-registry` crate. For instance, for this dummy model:
//
//     server(
//         crate_namespace = "burnlm_inference_template",
//         server_type = "ParrotServer<InferenceBackend>",
//     ),
//
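// A rough sketch of how that entry might sit inside the registry attribute
// (the exact shape of `inference_server_registry` may differ; check the
// registry crate's `lib.rs`):
//
//     #[inference_server_registry(
//         server(
//             crate_namespace = "burnlm_inference_template",
//             server_type = "ParrotServer<InferenceBackend>",
//         ),
//     )]
//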
#[derive(InferenceServer, Clone, Default, Debug)]
#[inference_server(
    model_name = "Parrot",
    model_creation_date = "01/28/2025",
    owned_by = "Tracel Technologies Inc."
)]
pub struct ParrotServer<B: Backend> {
    config: ParrotServerConfig,
    // Remove the phantom data and add your model here; see the TinyLlama example
    // in the `burn-lm-llama` crate.
    // You'll likely need to wrap your model in an `Arc<Mutex<...>>` because the
    // server needs to be cloneable.
    _model: std::marker::PhantomData<B>,
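    // A hypothetical replacement field could look like this (sketch; `MyModel`
    // is a placeholder for your actual model type):
    //
    //     model: std::sync::Arc<std::sync::Mutex<Option<MyModel<B>>>>,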
}

// Implement the `InferenceServer` trait for the server.
impl InferenceServer for ParrotServer<InferenceBackend> {
    fn downloader(&mut self) -> Option<fn() -> InferenceResult<Option<Stats>>> {
        // Return a closure with the code to download the model, if available.
        // Return `None` if the model cannot be downloaded or does not need to
        // be downloaded.
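        //
        // A hypothetical downloader could look like this (sketch only):
        //
        //     Some(|| {
        //         // Fetch the model weights into the local cache here.
        //         Ok(None)
        //     })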
        None
    }

    fn is_downloaded(&mut self) -> bool {
        // This server example does not require downloading,
        // thus it can be considered always installed.
        // Update accordingly.
        true
    }

    fn deleter(&mut self) -> Option<fn() -> InferenceResult<Option<Stats>>> {
        // Return a closure with the code to delete the model, if applicable.
        // Return `None` if the model cannot be deleted or does not need to
        // be deleted.
        None
    }

    fn load(&mut self) -> InferenceResult<Option<Stats>> {
        // Load the model here.
        let now = std::time::Instant::now();
        std::thread::sleep(std::time::Duration::from_secs(1));
        let mut stats = Stats::new();
        stats
            .entries
            .insert(StatEntry::ModelLoadingDuration(now.elapsed()));
        Ok(Some(stats))
    }

    fn is_loaded(&mut self) -> bool {
        // Return true when the model is loaded and ready for inference.
        true
    }

    fn unload(&mut self) -> InferenceResult<Option<Stats>> {
        // Drop the model here.
        Ok(None)
    }

    fn run_job(&mut self, job: InferenceJob) -> InferenceResult<Stats> {
        // The parrot simply echoes the input back as the generated text.
        match job.task {
            InferenceTask::Message(message) => {
                job.emitter.completed(GeneratedItem::Text(message.content));
            }
            InferenceTask::Context(mut messages) => {
                job.emitter.completed(GeneratedItem::Text(
                    messages
                        .pop()
                        .map(|msg| msg.content)
                        .unwrap_or_else(|| "...".to_string()),
                ));
            }
            InferenceTask::Prompt(text) => {
                job.emitter.completed(GeneratedItem::Text(text));
            }
        }

        // Example of returned statistics about the completion.
        let mut stats = Stats::default();
        stats
            .entries
            .insert(StatEntry::Named("Everything".to_string(), "42".to_string()));
        Ok(stats)
    }

    fn clear_state(&mut self) -> InferenceResult<()> {
        // No state to clear for this example.
        Ok(())
    }
}