1use std::{
2 collections::HashMap, fs, os::unix::fs::PermissionsExt, path::PathBuf, pin::Pin, sync::Arc,
3};
4
5use futures::{
6 future::{pending, FutureExt as _},
7 lock::Mutex,
8 select,
9 stream::FuturesUnordered,
10 Future, StreamExt as _,
11};
12use log::{debug, info, warn};
13use tokio::net::{unix::UCred, UnixListener};
14
15use krossbar_bus_common::{message::HubMessage, HUB_REGISTER_METHOD};
16use krossbar_common_rpc::{request::RpcRequest, rpc::Rpc, writer::RpcWriter, Error, Result};
17
18use crate::{args::Args, client::Client, permissions::Permissions};
19
20type TasksMapType = FuturesUnordered<Pin<Box<dyn Future<Output = Option<String>> + Send>>>;
21pub type ContextType = Arc<Mutex<HubContext>>;
22
23pub struct HubContext {
24 pub client_registry: HashMap<String, RpcWriter>,
25 pub pending_connections: HashMap<String, Vec<(String, RpcRequest)>>,
26 pub permissions: Permissions,
27}
28
29pub struct Hub {
30 tasks: TasksMapType,
31 socket_path: PathBuf,
32 context: ContextType,
33}
34
35impl Hub {
36 pub fn new(args: Args) -> Self {
37 let tasks: TasksMapType = FuturesUnordered::new();
38 tasks.push(Box::pin(pending()));
39
40 Self {
41 tasks,
42 socket_path: args.socket_path.clone(),
43 context: Arc::new(Mutex::new(HubContext {
44 client_registry: HashMap::new(),
45 pending_connections: HashMap::new(),
46 permissions: Permissions::new(&args.additional_service_dirs),
47 })),
48 }
49 }
50
51 pub async fn run(mut self) {
53 info!("Hub socket path: {:?}", self.socket_path);
54
55 let listener = match UnixListener::bind(&self.socket_path) {
56 Ok(listener) => listener,
57 Err(e) => {
58 warn!("Failed to start listening: {e:?}. Trying to remove hanging socket");
59
60 let _ = std::fs::remove_file(&self.socket_path);
61 let result = UnixListener::bind(&self.socket_path).unwrap();
62
63 result
64 }
65 };
66
67 info!("Hub started listening for new connections");
68
69 let socket_permissions = fs::Permissions::from_mode(0o666);
71 fs::set_permissions(&self.socket_path, socket_permissions).unwrap();
72
73 async move {
74 loop {
75 select! {
76 client = listener.accept().fuse() => {
78 match client {
79 Ok((stream, _)) => {
80 let credentials = stream.peer_cred();
81 let rpc = Rpc::new(stream);
82
83 match credentials {
84 Ok(credentials) => {
85 info!("New connection request: {credentials:?}");
86 let connection = Self::make_new_connection(rpc, credentials, self.context.clone());
87
88 self.tasks.push(Box::pin(connection))
89 },
90 Err(e) => {
91 warn!("Failed to get client creadentials: {}", e.to_string());
92 }
93 }
94
95 },
96 Err(e) => {
97 warn!("Failed client connection attempt: {}", e.to_string())
98 }
99 }
100 },
101 disconnected_service = self.tasks.next() => {
103 let service_name = disconnected_service.unwrap();
104
105 match service_name {
106 Some(service_name) => {
107 debug!("Client disconnected: {}", service_name);
108 self.context.lock().await.client_registry.remove(&service_name);
109 }
110 _ => {
111 debug!("Anonymous client disconnected");
112 }
113 }
114 },
115 _ = tokio::signal::ctrl_c().fuse() => return
116 }
117 }
118 }
119 .await;
120 }
121
122 async fn make_new_connection(
124 mut rpc: Rpc,
125 credentials: UCred,
126 context: ContextType,
127 ) -> Option<String> {
128 let service_name = match rpc.poll().await {
130 Some(mut request) => {
131 if request.endpoint() != HUB_REGISTER_METHOD {
132 request
133 .respond::<()>(Err(Error::InternalError(format!(
134 "Expected registration call from a client. Got {}",
135 request.endpoint()
136 ))))
137 .await;
138 }
139
140 match request.take_body().unwrap() {
141 krossbar_common_rpc::request::Body::Call(bson) => {
143 match bson::from_bson::<HubMessage>(bson) {
145 Ok(HubMessage::Register { service_name }) => {
146 match Self::handle_auth_request(
148 &service_name,
149 &request,
150 credentials,
151 &context,
152 )
153 .await
154 {
155 Ok(_) => {
156 info!("Succesfully authorized {service_name}");
157 request.respond(Ok(())).await;
158
159 service_name
160 }
161 Err(e) => {
162 warn!("Service {service_name} is not allowed to register");
163 request.respond::<()>(Err(e)).await;
164 return None;
165 }
166 }
167 }
168 Ok(m) => {
170 warn!("Invalid registration message from a client: {m:?}");
171
172 request
173 .respond::<()>(Err(Error::InternalError(format!(
174 "Invalid register message body: {m:?}"
175 ))))
176 .await;
177 return None;
178 }
179 Err(e) => {
181 warn!("Invalid Auth message body from a client: {e:?}");
182
183 request
184 .respond::<()>(Err(Error::InternalError(e.to_string())))
185 .await;
186 return None;
187 }
188 }
189 }
190 _ => {
192 warn!("Invalid Auth message from a client (not a call)");
193 return None;
194 }
195 }
196 }
197 _ => {
199 return None;
200 }
201 };
202
203 Some(Client::new(context.clone(), rpc, service_name).run().await)
205 }
206
207 async fn handle_auth_request(
209 service_name: &str,
210 request: &RpcRequest,
211 credentials: UCred,
212 context: &ContextType,
213 ) -> Result<()> {
214 debug!("Service registration request: {}", service_name);
215
216 let mut context_lock = context.lock().await;
217
218 if context_lock.client_registry.contains_key(service_name) {
220 warn!(
221 "Multiple service registration request from: {}",
222 service_name
223 );
224
225 return Err(Error::AlreadyRegistered);
226 } else {
228 if !context_lock
229 .permissions
230 .check_service_name_allowed(credentials, service_name)
231 {
232 debug!("Client {service_name} is not allowed to register with a given credentials");
233
234 return Err(Error::NotAllowed);
235 }
236
237 let mut writer = request.writer().clone();
238 Client::resolve_pending_connections(service_name, &mut writer, &mut context_lock).await;
239
240 context_lock
241 .client_registry
242 .insert(service_name.to_owned(), writer);
243
244 info!("Client authorized as: {}", service_name);
245
246 Ok(())
247 }
248 }
249}