1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
use std::{
any::Any,
collections::{HashMap, HashSet},
fs::File,
io::BufReader,
pin::Pin,
sync::{Arc, atomic::AtomicBool},
};
use actix_web::{
dev::Server,
middleware::{self, from_fn},
rt::System,
web::{self, Data},
App, HttpServer,
};
use async_trait::async_trait;
use bytes::Bytes;
use futures::future::BoxFuture;
use rustls::{RootCertStore, ServerConfig};
use rustls::server::WebPkiClientVerifier;
use serde_json::Value;
use tokio::{
spawn,
io::AsyncRead,
sync::{Mutex, Notify, RwLock, Semaphore},
};
use tokio_stream::Stream;
use crate::{
multibus::{create_bus, MultiBus},
pub_sub,
routes::{process_download, process_request, process_upload, process_metadata, shutdown_from_http, unauthorized, authentication_middleware},
};
/// A trait that all handlers must implement. It provides methods for running
/// business logic, publishing messages, and dispatching messages via a
/// communication line (MultiBus).
#[async_trait]
pub trait Base {
/// Asynchronously runs the handler logic, processing incoming data and using
/// the provided communication line.
///
/// # Arguments
/// * `src` - The source of the data.
/// * `data` - The data to be processed.
/// * `communication_line` - The communication line (MultiBus) for message exchange.
/// * `shared_state` - The variable used to access the shared memory between handlers
///
/// # Returns
/// * A `Result<String, Box<dyn std::error::Error>>` that can either return
/// a string message or an error.
async fn run(&self, src: String, data: String, communication_line: Arc<MultiBus>, shared_state: Arc<SharedState>) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
Err("run() not implemented for this handler.".into())
}
/// Processes a readable stream, suitable for streaming data from a Rocket request.
///
/// # Arguments
/// * `src` - The source of the data, usually representing the origin or identifier of the request.
/// * `stream` - A readable stream (e.g., from Rocket's `Data`) for receiving incoming data.
/// * `file_name` - The name of the file to be created from the stream content.
/// * `communication_line` - The communication line (`MultiBus`) for message exchange, enabling interaction between handlers.
/// * `shared_state` - The shared state for accessing shared memory between handlers or other components.
///
/// # Returns
/// * A `Result<String, Box<dyn std::error::Error>>` - On success, returns a `String` indicating success or relevant output.
/// In case of an error, a boxed error type is returned.
async fn run_stream(&self, src: String, stream: Pin<Box<dyn Stream<Item = Bytes> + Send>>, file_name: String, lower_bound: usize, communication_line: Arc<MultiBus>, shared_state: Arc<SharedState>) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
Err("run_stream() not implemented for this handler.".into())
}
/// Processes a filename and returns a bytestream and its size.
///
/// # Arguments
/// * `src` - The source of the data.
/// * `filename` - The name of the file to process.
/// * `communication_line` - The communication line for message exchange.
///
/// # Returns
/// * A `Result<(File, u64), Box<dyn std::error::Error>>`, where `File` is a readable stream,
/// and `u64` is the `Content-Length` (file size).
async fn run_file(&self, src: String, filename: String, communication_line: Arc<MultiBus>, shared_state: Arc<SharedState>) -> Result<(Box<dyn AsyncRead + Send + Unpin>, u64), Box<dyn std::error::Error + Send + Sync>> {
Err("run_file() not implemented for this handler.".into())
}
/// Returns metadata for a given file name.
///
/// # Arguments
/// * `src` - The source or origin of the request.
/// * `filename` - The name of the file whose metadata is being processed.
/// * `communication_line` - The communication line for message exchange, enabling interaction between components.
/// * `shared_state` - The shared state for accessing shared memory between handlers or other components.
///
/// # Returns
/// * `Result<String, Box<dyn std::error::Error + Send + Sync>>` - On success, a `String` containing
/// metadata information is returned. On failure, a boxed error is returned.
async fn run_metadata(&self, src: String, filename: String, communication_line: Arc<MultiBus>, shared_state: Arc<SharedState>) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
Err("run_metadata() not implemented for this handler.".into())
}
/// Publishes a message to a specific target via the communication line.
///
/// # Arguments
/// * `message` - The message to publish.
/// * `src` - The source of the message.
/// * `to` - The target of the message.
/// * `communication_line` - The communication line (MultiBus) for message exchange.
///
/// # Returns
/// * A `String` that contains the data from the response.
async fn publish(&self, message: String, src: String, to: String, communication_line: Arc<MultiBus>) -> String {
let result = pub_sub::publish(src, to, message, communication_line).await;
let parsed: Value = serde_json::from_str(&result).unwrap();
parsed["data"].to_string()
}
/// Dispatches a message to a specific target without awaiting a response.
///
/// # Arguments
/// * `message` - The message to dispatch.
/// * `to` - The target to which the message is dispatched.
/// * `communication_line` - The communication line (MultiBus) for message exchange.
async fn dispatch(&self, message: String, to: String, communication_line: Arc<MultiBus>) {
pub_sub::dispatch("not important".to_string(),to, message, communication_line).await;
}
/// Factory method for creating a new instance of the struct implementing the trait.
fn new() -> Self where Self: Sized;
}
/// A value that can be stored in [`SharedState`].
///
/// Besides numbers and strings it can hold callbacks (sync or async) and
/// arbitrary type-erased values. Those variants are wrapped in `Arc`, so cloning
/// a `StateType` shares them instead of copying them.
#[derive(Clone)]
pub enum StateType {
    /// No value; this is what `SharedState::get` returns for an unknown key.
    None,
    /// Synchronous callback mapping a `String` to a `String`.
    FunctionSync(Arc<dyn Fn(String) -> String + Send + Sync>),
    /// Asynchronous callback mapping a `String` to a future yielding a `String`.
    FunctionAsync(Arc<dyn Fn(String) -> BoxFuture<'static, String> + Send + Sync>),
    /// An owned string.
    String(String),
    /// A 32-bit signed integer.
    Int(i32),
    /// A 32-bit float.
    Float(f32),
    /// A 64-bit signed integer.
    Long(i64),
    /// A 64-bit float.
    Double(f64),
    /// Any value that is `Send + Sync`; read it back with `SharedState::get_any`.
    AnyType(Arc<dyn Any + Send + Sync>)
}
/// Key/value store shared between handlers.
///
/// The map sits behind an asynchronous `tokio` `RwLock`, so any number of
/// readers, or a single writer, can hold it at a time.
pub struct SharedState {
    /// Backing map from key to stored [`StateType`].
    elements: RwLock<HashMap<String, StateType>>,
}
impl SharedState {
/// Retrieves a value from the shared state given a key.
///
/// # Arguments
/// * `key` - A reference to the key (`String`) used to lookup the value.
///
/// # Returns
/// * `StateType` - The value associated with the key, or `StateType::None` if the key is not found.
///
/// ```
pub async fn get(&self, key: &String) -> StateType {
let elem_lock = self.elements.read().await;
let result = elem_lock.get(key);
match result {
None => {
StateType::None
},
Some(value) => {
value.clone()
}
}
}
/// Inserts a key-value pair into the shared state.
///
/// # Arguments
/// * `key` - A reference to the key (`String`) for the value.
/// * `data` - The `StateType` value to be stored in the shared state.
///
/// ```
pub async fn insert(&self, key: &String, data: StateType) {
let mut elem_lock = self.elements.write().await;
elem_lock.insert(key.clone(), data.clone());
drop(elem_lock);
}
/// Deletes a key-value pair from the shared state.
///
/// # Arguments
/// * `key` - A reference to the key (`String`) to remove from the shared state.
///
/// ```
pub async fn delete(&self, key: &String) {
let mut elem_lock = self.elements.write().await;
elem_lock.remove(key);
drop(elem_lock);
}
/// Attempts to downcast the stored `AnyType` value to the specified type.
///
/// # Arguments
/// * `key` - A reference to the key (`String`) used to lookup the value.
///
/// # Returns
/// * `Option<Arc<T>>` - The downcast value if successful, or `None` if downcast fails.
pub async fn get_any<T: 'static + Send + Sync>(&self, key: &String) -> Option<Arc<T>>{
let elem_lock = self.elements.read().await;
match elem_lock.get(key) {
Some(StateType::AnyType(value)) => value.clone().downcast::<T>().ok(),
_ => None
}
}
}
/// Settings handed to the HTTP routes as application data.
#[derive(Clone)]
pub struct Config {
    /// API key given to `Manager::new` (presumably checked by the
    /// authentication middleware — TODO confirm).
    pub(crate) api_key: String,
}
/// Owns the registered handlers, the message bus they share and the tasks that
/// feed them; `Manager::start` also runs the HTTP server.
pub struct Manager {
    /// Factory per handler name; every call builds a fresh handler instance.
    instance: HashMap<String, Arc<Box<dyn Fn () -> Box<dyn Base + Send + Sync> + Send + Sync>>>,
    /// Maximum number of concurrent runs allowed per handler name.
    number_replicas: HashMap<String, i32>,
    /// Message bus shared by the manager, the handlers and the HTTP routes.
    communication_line: Arc<MultiBus>,
    /// Join handles of the per-handler listener tasks spawned by `start`.
    listeners: Vec<tokio::task::JoinHandle<()>>,
    /// Maximum number of HTTP requests served at the same time.
    nr_requests: i32,
    /// State shared between handlers and exposed to the routes.
    shared_state: Arc<SharedState>,
    /// API key forwarded to the routes inside a `Config`.
    api_key: String,
    /// PEM certificate chain; TLS is enabled only together with `key_path`.
    cert_path: Option<String>,
    /// PEM private key for TLS.
    key_path: Option<String>,
    /// PEM CA bundle used to verify client certificates.
    ca_path: Option<String>,
    /// Optional list of names forwarded to the routes as a set.
    allowed_names: Option<Vec<String>>,
}
impl Manager {
/// Creates a new instance of `Manager`.
/// # Arguments
/// * `nr_requests` - The number of requests it can handle simultaneously
/// # Returns
/// * A `Manager` with no registered handlers, initialized listeners, and communication line.
pub fn new(nr_requests: i32, api_key: String, cert_path: Option<String>, key_path: Option<String>, ca_path: Option<String>, allowed_names: Option<Vec<String>>) -> Manager {
Manager {
instance: HashMap::new(),
number_replicas: HashMap::new(),
communication_line: create_bus(),
listeners: Vec::new(),
nr_requests,
shared_state: Arc::new(SharedState { elements: Default::default() }),
api_key,
cert_path,
key_path,
ca_path,
allowed_names
}
}
/// Registers a new handler to the manager.
///
/// # Arguments
/// * `name` - The name of the handler.
/// * `nr_replicas` - The number of replicas a handler can have at any time
///
/// # Panics
/// * Panics if the handler with the same name is already added.
pub fn add_handler<T>(&mut self, name: &str, nr_replicas: i32)
where T: Base + Send + Sync + 'static {
if self.instance.contains_key(name) {
panic!("Can't add handler {} because it is already added", name);
}
let factory: Box<dyn Fn () -> Box<dyn Base + Send + Sync> + Send + Sync> = Box::new(|| {
Box::new(T::new())
});
self.instance.insert(name.to_string(), Arc::new(factory));
self.number_replicas.insert(name.to_string(), nr_replicas);
}
pub async fn add_global_state(&mut self, key: &str, state_type: StateType) {
self.shared_state.insert(&key.to_owned(), state_type).await;
}
/// Starts the manager, initializes the handlers, and launches the Rocket server.
///
/// This method also sets up listeners for incoming messages for each handler.
///
/// # Panics
/// * Panics if the Rocket server fails to launch.
pub async fn start(&mut self) {
println!("Initializing the listeners for the handlers");
pub_sub::setup_publishing("manager".to_owned(), self.communication_line.clone()).await;
let shared_state = self.shared_state.clone();
for (name, instance) in self.instance.iter() {
pub_sub::setup_publishing(name.clone(), self.communication_line.clone()).await;
let communication_line = Arc::clone(&self.communication_line);
let instance_to_run: Arc<Box<dyn Fn () -> Box<dyn Base + Send + Sync> + Send + Sync>> = Arc::clone(instance);
let nr_times = self.number_replicas.get(name).unwrap().clone();
let name = name.clone();
let shared_state_clone = shared_state.clone();
let handle = spawn(async move {
let nr_times_cn = nr_times.clone().to_owned();
let communication_line_clone_2 = communication_line.clone();
let standby_handlers: Arc<Mutex<Vec<i32>>> = Arc::new(Mutex::new((0..nr_times_cn).collect()));
let notification = Arc::new(Notify::new());
let name_cn = name.clone();
let shared_state_clone_2 = shared_state_clone.clone();
while let Some(data) = communication_line.clone().request_data(name.clone()).await {
let standby_handlers_i = standby_handlers.clone();
let name_cn_i = name_cn.clone();
let communication_line_i = communication_line_clone_2.clone();
let instance_run_i = instance_to_run.clone();
let notification_i = notification.clone();
let shared_state_clone_i = shared_state_clone_2.clone();
let parsed_data: Value = serde_json::from_str(&data).unwrap();
let data_for_instance = parsed_data["data"]
.as_str()
.expect("Should have at least an empty string")
.to_string();
let mut standby_handlers_lock = standby_handlers_i.lock().await;
let mut id = standby_handlers_lock.pop();
drop(standby_handlers_lock);
if id == None {
notification.clone().notified().await;
standby_handlers_lock = standby_handlers_i.lock().await;
id = standby_handlers_lock.pop();
drop(standby_handlers_lock);
}
let id_unwrapped = id.unwrap();
spawn(async move {
let result = instance_run_i().run(name_cn_i.clone(), data_for_instance, communication_line_i.clone(), shared_state_clone_i.clone()).await;
if let Err(e) = result {
eprintln!("Errored while running handler {}: {}", name_cn_i, e);
} else {
let value_to_return = result.unwrap();
if parsed_data["type"].as_str().unwrap() == "publish" {
pub_sub::dispatch(name_cn_i.to_string(), parsed_data["src"].as_str().unwrap().to_string(), value_to_return, communication_line_i.clone()).await;
}
}
let mut standby_handlers_lock = standby_handlers_i.lock().await;
standby_handlers_lock.push(id_unwrapped);
drop(standby_handlers_lock);
notification_i.notify_waiters();
});
}
});
self.listeners.push(handle);
}
println!("Initializing the server listener");
rustls::crypto::aws_lc_rs::default_provider()
.install_default()
.unwrap();
let mut tls_config: Option<ServerConfig> = None;
if let Some(cert_path) = self.cert_path.clone() {
if let Some(key_cert) = self.key_path.clone() {
let mut certs_file = BufReader::new(File::open(cert_path).unwrap());
let mut key_file = BufReader::new(File::open(key_cert).unwrap());
let tls_certs = rustls_pemfile::certs(&mut certs_file)
.collect::<Result<Vec<_>, _>>()
.unwrap();
let tls_key = rustls_pemfile::rsa_private_keys(&mut key_file)
.next()
.unwrap()
.unwrap();
if let Some(ca_path) = self.ca_path.clone() {
let mut ca_file = BufReader::new(File::open(ca_path).unwrap());
let ca_certs = rustls_pemfile::certs(&mut ca_file)
.collect::<Result<Vec<_>, _>>()
.unwrap();
let mut root_store = RootCertStore::empty();
for cert in ca_certs {
root_store.add(cert).unwrap();
}
let client_cert_verifier = WebPkiClientVerifier::builder(<Arc<RootCertStore>>::from(root_store))
.build()
.expect("Failed to create client certificate verifier");
tls_config = Some(ServerConfig::builder()
.with_client_cert_verifier(client_cert_verifier)
.with_single_cert(tls_certs, rustls::pki_types::PrivateKeyDer::Pkcs1(tls_key))
.unwrap());
} else {
tls_config = Some(ServerConfig::builder()
.with_no_client_auth()
.with_single_cert(tls_certs, rustls::pki_types::PrivateKeyDer::Pkcs1(tls_key))
.unwrap())
}
}
}
let end_notifier = Arc::new(Notify::new());
let one_request_at_a_time = Arc::new(Semaphore::new(self.nr_requests as usize));
let has_been_called = Arc::new(AtomicBool::new(false));
let api_key = Config {api_key: self.api_key.clone()};
let request_max_per_handler = Arc::new(self.number_replicas
.iter()
.map(|val| (val.0.clone(), Arc::new(Semaphore::new(*val.1 as usize))))
.collect::<HashMap<String, Arc<Semaphore>>>());
let mut allowed_names_set: Arc<Option<HashSet<String>>> = Arc::new(None);
if let Some(allowed_names) = self.allowed_names.clone() {
allowed_names_set = Arc::new(Some(HashSet::from_iter(allowed_names)));
}
let instance_clone = Arc::new(self.instance.clone());
let communication_line_clone = self.communication_line.clone();
let end_notifier_clone = end_notifier.clone();
unsafe {
std::env::set_var("RUST_LOG", "debug");
}
env_logger::init();
let server_cfg = HttpServer::new(move || {
App::new()
.app_data(web::PayloadConfig::new(10 * 1024 * 1024 * 1024)) // 10 GiB
.app_data(web::JsonConfig::default().limit(1024 * 1024 * 1024)) // 1 GiB
.app_data(Data::new(end_notifier_clone.clone()))
.app_data(Data::new(one_request_at_a_time.clone()))
.app_data(Data::new(request_max_per_handler.clone()))
.app_data(Data::new(has_been_called.clone()))
.app_data(Data::new(api_key.clone()))
.app_data(Data::new(instance_clone.clone()))
.app_data(Data::new(communication_line_clone.clone()))
.app_data(Data::new(shared_state.clone()))
.app_data(Data::new(allowed_names_set.clone()))
.wrap(from_fn(authentication_middleware))
.wrap(middleware::Logger::default())
.service(
web::resource("/shutdown")
.to(shutdown_from_http)
)
.service(
web::resource("/{handler_name}")
.to(process_request)
)
.service(
web::scope("/stream")
.service(web::resource("/upload/{file_name}").to(process_upload))
.service(web::resource("/download/{file_id}").to(process_download))
.service(web::resource("/metadata/{file_id}").to(process_metadata))
)
.default_service(
web::route()
.to(unauthorized)
)
});
let mut server: Server;
if let Some(tls) = tls_config {
server = server_cfg.bind_rustls_0_23(("0.0.0.0", 8080), tls).unwrap().run();
} else {
server = server_cfg.bind(("0.0.0.0", 8080)).unwrap().run();
}
tokio::select! {
_ = server => {},
_ = end_notifier.notified() => {
System::current().stop();
}
}
}
/// Forcefully terminates all listeners managed by the `Manager`.
///
/// This function aborts all tasks without waiting for them to finish, ensuring
/// an immediate stop of all handlers and listeners.
pub fn force_finish_all(&mut self) {
self.listeners.iter().for_each(|elem| {
elem.abort();
});
self.listeners.clear();
}
}