krossbar_hub_lib/
hub.rs

1use std::{
2    collections::HashMap, fs, os::unix::fs::PermissionsExt, path::PathBuf, pin::Pin, sync::Arc,
3};
4
5use futures::{
6    future::{pending, FutureExt as _},
7    lock::Mutex,
8    select,
9    stream::FuturesUnordered,
10    Future, StreamExt as _,
11};
12use log::{debug, info, warn};
13use tokio::net::{unix::UCred, UnixListener};
14
15use krossbar_bus_common::{message::HubMessage, HUB_REGISTER_METHOD};
16use krossbar_common_rpc::{request::RpcRequest, rpc::Rpc, writer::RpcWriter, Error, Result};
17
18use crate::{args::Args, client::Client, permissions::Permissions};
19
20type TasksMapType = FuturesUnordered<Pin<Box<dyn Future<Output = Option<String>> + Send>>>;
21pub type ContextType = Arc<Mutex<HubContext>>;
22
23pub struct HubContext {
24    pub client_registry: HashMap<String, RpcWriter>,
25    pub pending_connections: HashMap<String, Vec<(String, RpcRequest)>>,
26    pub permissions: Permissions,
27}
28
29pub struct Hub {
30    tasks: TasksMapType,
31    socket_path: PathBuf,
32    context: ContextType,
33}
34
35impl Hub {
36    pub fn new(args: Args) -> Self {
37        let tasks: TasksMapType = FuturesUnordered::new();
38        tasks.push(Box::pin(pending()));
39
40        Self {
41            tasks,
42            socket_path: args.socket_path.clone(),
43            context: Arc::new(Mutex::new(HubContext {
44                client_registry: HashMap::new(),
45                pending_connections: HashMap::new(),
46                permissions: Permissions::new(&args.additional_service_dirs),
47            })),
48        }
49    }
50
51    /// Hub main loop
52    pub async fn run(mut self) {
53        info!("Hub socket path: {:?}", self.socket_path);
54
55        let listener = match UnixListener::bind(&self.socket_path) {
56            Ok(listener) => listener,
57            Err(e) => {
58                warn!("Failed to start listening: {e:?}. Trying to remove hanging socket");
59
60                let _ = std::fs::remove_file(&self.socket_path);
61                let result = UnixListener::bind(&self.socket_path).unwrap();
62
63                result
64            }
65        };
66
67        info!("Hub started listening for new connections");
68
69        // Update permissions to be accessible for th eclient
70        let socket_permissions = fs::Permissions::from_mode(0o666);
71        fs::set_permissions(&self.socket_path, socket_permissions).unwrap();
72
73        async move {
74            loop {
75                select! {
76                    // Accept new connection requests
77                    client = listener.accept().fuse() => {
78                        match client {
79                            Ok((stream, _)) => {
80                                let credentials = stream.peer_cred();
81                                let rpc = Rpc::new(stream);
82
83                                match credentials {
84                                    Ok(credentials) => {
85                                        info!("New connection request: {credentials:?}");
86                                        let connection = Self::make_new_connection(rpc, credentials, self.context.clone());
87
88                                        self.tasks.push(Box::pin(connection))
89                                    },
90                                    Err(e) => {
91                                        warn!("Failed to get client creadentials: {}", e.to_string());
92                                    }
93                                }
94
95                            },
96                            Err(e) => {
97                                warn!("Failed client connection attempt: {}", e.to_string())
98                            }
99                        }
100                    },
101                    // Loop clients. Loop return means a client is disconnected
102                    disconnected_service = self.tasks.next() => {
103                        let service_name = disconnected_service.unwrap();
104
105                        match service_name {
106                            Some(service_name) => {
107                                debug!("Client disconnected: {}", service_name);
108                                self.context.lock().await.client_registry.remove(&service_name);
109                            }
110                            _ => {
111                                debug!("Anonymous client disconnected");
112                            }
113                        }
114                    },
115                    _ = tokio::signal::ctrl_c().fuse() => return
116                }
117            }
118        }
119        .await;
120    }
121
122    /// Make a connection form a stream
123    async fn make_new_connection(
124        mut rpc: Rpc,
125        credentials: UCred,
126        context: ContextType,
127    ) -> Option<String> {
128        // Authorize the client
129        let service_name = match rpc.poll().await {
130            Some(mut request) => {
131                if request.endpoint() != HUB_REGISTER_METHOD {
132                    request
133                        .respond::<()>(Err(Error::InternalError(format!(
134                            "Expected registration call from a client. Got {}",
135                            request.endpoint()
136                        ))))
137                        .await;
138                }
139
140                match request.take_body().unwrap() {
141                    // Valid call message
142                    krossbar_common_rpc::request::Body::Call(bson) => {
143                        // Valid Auth message
144                        match bson::from_bson::<HubMessage>(bson) {
145                            Ok(HubMessage::Register { service_name }) => {
146                                // Check permissions
147                                match Self::handle_auth_request(
148                                    &service_name,
149                                    &request,
150                                    credentials,
151                                    &context,
152                                )
153                                .await
154                                {
155                                    Ok(_) => {
156                                        info!("Succesfully authorized {service_name}");
157                                        request.respond(Ok(())).await;
158
159                                        service_name
160                                    }
161                                    Err(e) => {
162                                        warn!("Service {service_name} is not allowed to register");
163                                        request.respond::<()>(Err(e)).await;
164                                        return None;
165                                    }
166                                }
167                            }
168                            // Connection request instead of an Auth message
169                            Ok(m) => {
170                                warn!("Invalid registration message from a client: {m:?}");
171
172                                request
173                                    .respond::<()>(Err(Error::InternalError(format!(
174                                        "Invalid register message body: {m:?}"
175                                    ))))
176                                    .await;
177                                return None;
178                            }
179                            // Message deserialization error
180                            Err(e) => {
181                                warn!("Invalid Auth message body from a client: {e:?}");
182
183                                request
184                                    .respond::<()>(Err(Error::InternalError(e.to_string())))
185                                    .await;
186                                return None;
187                            }
188                        }
189                    }
190                    // Not a call, but respond, of FD or other irrelevant message
191                    _ => {
192                        warn!("Invalid Auth message from a client (not a call)");
193                        return None;
194                    }
195                }
196            }
197            // Client disconnected
198            _ => {
199                return None;
200            }
201        };
202
203        // Cient succesfully authorized. Start client loop
204        Some(Client::new(context.clone(), rpc, service_name).run().await)
205    }
206
207    /// Handle client Auth message
208    async fn handle_auth_request(
209        service_name: &str,
210        request: &RpcRequest,
211        credentials: UCred,
212        context: &ContextType,
213    ) -> Result<()> {
214        debug!("Service registration request: {}", service_name);
215
216        let mut context_lock = context.lock().await;
217
218        // Check if we already have a client with the name
219        if context_lock.client_registry.contains_key(service_name) {
220            warn!(
221                "Multiple service registration request from: {}",
222                service_name
223            );
224
225            return Err(Error::AlreadyRegistered);
226        // The only valid Auth request path
227        } else {
228            if !context_lock
229                .permissions
230                .check_service_name_allowed(credentials, service_name)
231            {
232                debug!("Client {service_name} is not allowed to register with a given credentials");
233
234                return Err(Error::NotAllowed);
235            }
236
237            let mut writer = request.writer().clone();
238            Client::resolve_pending_connections(service_name, &mut writer, &mut context_lock).await;
239
240            context_lock
241                .client_registry
242                .insert(service_name.to_owned(), writer);
243
244            info!("Client authorized as: {}", service_name);
245
246            Ok(())
247        }
248    }
249}