burn_lm_parrot/
lib.rs

#![recursion_limit = "256"]

use burn_lm_inference::{InferenceJob, *};

// This is where you declare the configuration parameters for
// your model.
//
// On each field, add the `config` attribute to define a default value for it.
//
// These parameters will be available in the `burn-lm` CLI.
//
// Parameters sent by Open WebUI will be automatically mapped to the
// configuration parameters with the same name. It is also possible to map to
// an Open WebUI parameter with a different name using
// `#[config(openwebui_param = "param_name")]`.
// Examples of Open WebUI parameters are `temperature`, `seed` and `top_p`.
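//
// For example, a hypothetical field whose name differs from the Open WebUI
// parameter it maps to (a sketch only; `rng_seed` is not part of this
// template):
//
//     /// Seed for reproducible sampling, filled from Open WebUI's `seed`.
//     #[config(openwebui_param = "seed")]
//     pub rng_seed: u64,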
#[inference_server_config]
pub struct ParrotServerConfig {
    /// Temperature value for controlling randomness in sampling.
    #[config(default = 0.1)]
    pub temperature: f64,
}

// Declare the model server info using the `InferenceServer` derive
// and the `inference_server` attribute. The structure must be generic over
// the Burn backends.
//
// Register the model by adding a dependency on this crate in the
// `burn-lm-registry` crate.
//
// Then add an entry to the `inference_server_registry` attribute in `lib.rs`
// of the `burn-lm-registry` crate. For instance, for this dummy model:
//
//     server(
//         crate_namespace = "burnlm_inference_template",
//         server_type = "ParrotServer<InferenceBackend>",
//     ),
//
#[derive(InferenceServer, Clone, Default, Debug)]
#[inference_server(
    model_name = "Parrot",
    model_creation_date = "01/28/2025",
    owned_by = "Tracel Technologies Inc."
)]
pub struct ParrotServer<B: Backend> {
    config: ParrotServerConfig,
    // Remove the phantom data and add your model here; see the TinyLlama
    // example in the burn-lm-llama crate.
    // You'll likely need to wrap your model in an `Arc<Mutex<...>>` because
    // the server needs to be cloneable.
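    //
    // A hypothetical shape for that field (`MyModel` is illustrative, not a
    // real type in this workspace):
    //
    //     model: std::sync::Arc<std::sync::Mutex<Option<MyModel<B>>>>,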
    _model: std::marker::PhantomData<B>,
}

// Implement the `InferenceServer` trait for the server.
impl InferenceServer for ParrotServer<InferenceBackend> {
    fn downloader(&mut self) -> Option<fn() -> InferenceResult<Option<Stats>>> {
        // Return a closure with the code to download the model, if applicable.
        // Return `None` if the model cannot be downloaded or does not need
        // to be downloaded.
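        //
        // A sketch for a downloadable model (`download_weights` is a
        // hypothetical helper, not part of this workspace):
        //
        //     Some(|| {
        //         download_weights()?;
        //         Ok(None)
        //     })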
        None
    }

    fn is_downloaded(&mut self) -> bool {
        // This example server does not require downloading,
        // so it can be considered always downloaded.
        // Update accordingly.
        true
    }

    fn deleter(&mut self) -> Option<fn() -> InferenceResult<Option<Stats>>> {
        // Return a closure with the code to delete the model, if applicable.
        // Return `None` if the model cannot be deleted or does not need
        // to be deleted.
        None
    }

    fn load(&mut self) -> InferenceResult<Option<Stats>> {
        // Load the model here. This example simply sleeps for one second
        // to simulate loading work, then reports the elapsed time.
        let now = std::time::Instant::now();
        std::thread::sleep(std::time::Duration::from_secs(1));
        let mut stats = Stats::new();
        stats
            .entries
            .insert(StatEntry::ModelLoadingDuration(now.elapsed()));
        Ok(Some(stats))
    }

    fn is_loaded(&mut self) -> bool {
        // Return true when the model is loaded and ready for inference.
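        //
        // With the hypothetical `Arc<Mutex<Option<MyModel<B>>>>` field
        // sketched above, this could be:
        //
        //     self.model.lock().unwrap().is_some()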
        true
    }

    fn unload(&mut self) -> InferenceResult<Option<Stats>> {
        // Drop the model here.
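        //
        // With the hypothetical model field sketched above, this could be:
        //
        //     *self.model.lock().unwrap() = None;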
        Ok(None)
    }

    fn run_job(&mut self, job: InferenceJob) -> InferenceResult<Stats> {
        match job.task {
            // Echo a single message back verbatim: the "parrot" behavior.
            InferenceTask::Message(message) => {
                job.emitter.completed(GeneratedItem::Text(message.content));
            }
            // For a conversation context, echo the most recent message,
            // falling back to "..." when the context is empty.
            InferenceTask::Context(mut messages) => {
                job.emitter.completed(GeneratedItem::Text(
                    messages
                        .pop()
                        .map(|msg| msg.content)
                        .unwrap_or_else(|| "...".to_string()),
                ));
            }
            // Echo a raw prompt back verbatim.
            InferenceTask::Prompt(text) => {
                job.emitter.completed(GeneratedItem::Text(text));
            }
        }

        // Example of returned statistics about the completion.
        let mut stats = Stats::default();
        stats
            .entries
            .insert(StatEntry::Named("Everything".to_string(), "42".to_string()));
        Ok(stats)
    }

    fn clear_state(&mut self) -> InferenceResult<()> {
        // No state to clear in this example.
        Ok(())
    }
}