use std::collections::HashMap;
use crate::network::auth::ConnectionAuthorizationType;
use crate::transport::Connection;
use super::{AuthorizationResult, Authorizer, AuthorizerCallback, AuthorizerError};
/// Authorizer that trusts a fixed, preconfigured set of remote endpoints.
///
/// A connection is authorized only when its remote endpoint is an exact key in
/// `endpoint_to_identities`; the mapped value becomes the connection's `Trust` identity.
pub struct InprocAuthorizer {
    /// Remote endpoint -> identity that connections from that endpoint are trusted as.
    endpoint_to_identities: HashMap<String, String>,
    /// This node's identity, reported as the local authorization on success.
    node_id: String,
}
impl InprocAuthorizer {
pub fn new<I>(identities: I, node_id: String) -> Self
where
I: IntoIterator<Item = (String, String)>,
{
Self {
endpoint_to_identities: identities.into_iter().collect(),
node_id,
}
}
}
impl Authorizer for InprocAuthorizer {
    /// Authorizes `connection` by looking up its remote endpoint in the configured
    /// endpoint-to-identity table.
    ///
    /// If the endpoint is known, `on_complete` receives `AuthorizationResult::Authorized`.
    /// Both `identity` and `expected_authorization` are `Trust` with the configured
    /// identity, and `local_authorization` is `Trust` with this node's id. Otherwise
    /// `on_complete` receives `AuthorizationResult::Unauthorized`. The
    /// `_expected_authorization` and `_local_authorization` arguments are ignored.
    ///
    /// # Errors
    ///
    /// Returns an `AuthorizerError` carrying the callback's error message if
    /// `on_complete` fails.
    fn authorize_connection(
        &self,
        connection_id: String,
        connection: Box<dyn Connection>,
        on_complete: AuthorizerCallback,
        _expected_authorization: Option<ConnectionAuthorizationType>,
        _local_authorization: Option<ConnectionAuthorizationType>,
    ) -> Result<(), AuthorizerError> {
        let remote_endpoint = connection.remote_endpoint();

        let result = match self.endpoint_to_identities.get(&remote_endpoint) {
            Some(peer_identity) => AuthorizationResult::Authorized {
                connection_id,
                // The identity is borrowed from the lookup table, so each use needs its
                // own owned copy.
                identity: ConnectionAuthorizationType::Trust {
                    identity: peer_identity.clone(),
                },
                connection,
                expected_authorization: ConnectionAuthorizationType::Trust {
                    identity: peer_identity.clone(),
                },
                local_authorization: ConnectionAuthorizationType::Trust {
                    identity: self.node_id.clone(),
                },
            },
            None => AuthorizationResult::Unauthorized {
                connection_id,
                connection,
            },
        };

        (*on_complete)(result).map_err(|err| AuthorizerError(err.to_string()))
    }
}
/// An [`Authorizer`] that dispatches to one of several registered authorizers, chosen by
/// the prefix of the connection's remote endpoint.
#[derive(Default)]
pub struct Authorizers {
    /// `(match_prefix, authorizer)` pairs, kept in registration order.
    authorizers: Vec<(String, Box<dyn Authorizer + Send>)>,
}
impl Authorizers {
    /// Creates a dispatcher with no registered authorizers.
    pub fn new() -> Self {
        Self::default()
    }

    /// Registers `authorizer` for remote endpoints that begin with `match_prefix`.
    ///
    /// Registrations are appended, and lookups consult them in registration order with the
    /// first matching prefix winning. Register more specific prefixes before shorter ones
    /// that would also match. An empty prefix matches every endpoint.
    pub fn add_authorizer(
        &mut self,
        match_prefix: &str,
        authorizer: impl Authorizer + 'static + Send,
    ) {
        let boxed: Box<dyn Authorizer + Send> = Box::new(authorizer);
        self.authorizers.push((String::from(match_prefix), boxed));
    }
}
impl Authorizer for Authorizers {
    /// Delegates authorization to the first registered authorizer whose prefix matches the
    /// connection's remote endpoint.
    ///
    /// Candidates are checked in registration order and the first match wins. For
    /// example, `"inproc2"` must be registered before `"inproc"` to be reachable. An empty
    /// prefix matches every endpoint. All arguments are forwarded unchanged to the
    /// selected authorizer.
    ///
    /// # Errors
    ///
    /// Returns whatever the selected authorizer returns. If no registered prefix matches,
    /// returns an `AuthorizerError` naming the connection id and endpoint; in that case
    /// `on_complete` is dropped without being invoked.
    fn authorize_connection(
        &self,
        connection_id: String,
        connection: Box<dyn Connection>,
        on_complete: AuthorizerCallback,
        expected_authorization: Option<ConnectionAuthorizationType>,
        local_authorization: Option<ConnectionAuthorizationType>,
    ) -> Result<(), AuthorizerError> {
        // `remote_endpoint()` returns an owned `String`; fetch it once rather than
        // re-allocating it for every candidate prefix and again for the error message.
        let remote_endpoint = connection.remote_endpoint();

        let selected = self
            .authorizers
            .iter()
            .find(|(match_prefix, _)| remote_endpoint.starts_with(match_prefix.as_str()));

        match selected {
            Some((_, authorizer)) => authorizer.authorize_connection(
                connection_id,
                connection,
                on_complete,
                expected_authorization,
                local_authorization,
            ),
            None => Err(AuthorizerError(format!(
                "no authorizer found for {} ({})",
                connection_id, remote_endpoint
            ))),
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    use std::sync::mpsc;

    use crate::transport::{Connection, DisconnectError, RecvError, SendError};

    /// Runs `authorizer` against a mock connection whose remote endpoint is `endpoint` and
    /// returns the result delivered to the completion callback.
    ///
    /// Panics if `authorize_connection` returns an error, or if the callback is dropped
    /// without being invoked.
    fn authorize(authorizer: &dyn Authorizer, endpoint: &str) -> AuthorizationResult {
        let (tx, rx) = mpsc::channel();
        authorizer
            .authorize_connection(
                "abcd-1234".into(),
                Box::new(MockConnection::new(endpoint)),
                Box::new(move |result| tx.send(result).map_err(Box::from)),
                None,
                None,
            )
            .expect("authorize_connection should succeed");
        rx.recv().expect("callback should have delivered a result")
    }

    /// Asserts that `result` is `Authorized` with a `Trust` identity equal to `expected`.
    fn assert_trusted(result: AuthorizationResult, expected: &str) {
        match result {
            AuthorizationResult::Authorized { identity, .. } => assert_eq!(
                ConnectionAuthorizationType::Trust {
                    identity: expected.into()
                },
                identity
            ),
            AuthorizationResult::Unauthorized { .. } => panic!("should have been authorized"),
        }
    }

    /// Asserts that `result` is `Unauthorized`.
    fn assert_unauthorized(result: AuthorizationResult) {
        match result {
            AuthorizationResult::Authorized { .. } => panic!("should not have been authorized"),
            AuthorizationResult::Unauthorized { .. } => (),
        }
    }

    /// A configured endpoint is authorized with its configured identity.
    #[test]
    fn inproc_configured_authorization() {
        let authorizer = InprocAuthorizer::new(
            vec![("inproc://test-conn".to_string(), "test-ident1".to_string())],
            "node_id".to_string(),
        );

        assert_trusted(authorize(&authorizer, "inproc://test-conn"), "test-ident1");
    }

    /// An endpoint missing from the table is reported as unauthorized.
    #[test]
    fn inproc_unconfigured_authorization() {
        let authorizer = InprocAuthorizer::new(
            vec![("inproc://test-conn".to_string(), "test-ident1".to_string())],
            "node_id".to_string(),
        );

        assert_unauthorized(authorize(&authorizer, "inproc://bad-inproc-conn"));
    }

    /// `Authorizers` routes each connection to the first authorizer whose prefix matches.
    #[test]
    fn authorizers_configured_authorizations() {
        let inproc_authorizer = InprocAuthorizer::new(
            vec![("inproc://test-conn".to_string(), "test-ident1".to_string())],
            "node_id".to_string(),
        );
        let future_inproc_authorizer = NoopAuthorizer::new("test-ident2");
        let default_authorizer = InprocAuthorizer::new(
            vec![(
                "protocol://other-conn".to_string(),
                "test-ident3".to_string(),
            )],
            "node_id".to_string(),
        );

        let mut authorizers = Authorizers::new();
        // First match wins, so "inproc2" must precede "inproc" (which is also a prefix of
        // "inproc2://..."), and the catch-all "" must come last.
        authorizers.add_authorizer("inproc2", future_inproc_authorizer);
        authorizers.add_authorizer("inproc", inproc_authorizer);
        authorizers.add_authorizer("", default_authorizer);

        assert_trusted(authorize(&authorizers, "inproc://test-conn"), "test-ident1");
        assert_trusted(authorize(&authorizers, "inproc2://test-conn"), "test-ident2");
        assert_trusted(authorize(&authorizers, "protocol://other-conn"), "test-ident3");
        // Falls through to the catch-all authorizer, which does not know this endpoint.
        assert_unauthorized(authorize(&authorizers, "tcp://some-tcp:4444"));
    }

    /// With no matching prefix, `Authorizers` returns an error and never invokes the
    /// callback.
    #[test]
    fn authorizers_no_matching_authorizer() {
        let mut authorizers = Authorizers::new();
        authorizers.add_authorizer(
            "inproc",
            InprocAuthorizer::new(
                vec![("inproc://test-conn".to_string(), "test-ident1".to_string())],
                "node_id".to_string(),
            ),
        );

        let (tx, rx) = mpsc::channel();
        let outcome = authorizers.authorize_connection(
            "abcd-1234".into(),
            Box::new(MockConnection::new("tcp://some-tcp:4444")),
            Box::new(move |result| tx.send(result).map_err(Box::from)),
            None,
            None,
        );

        assert!(outcome.is_err(), "expected an error when no prefix matches");
        // The callback (and its sender) was dropped without sending anything.
        assert!(rx.recv().is_err(), "callback should not have been invoked");
    }

    /// Minimal `Connection` whose only meaningful behavior is reporting a fixed remote
    /// endpoint.
    struct MockConnection {
        remote_endpoint: String,
    }

    impl MockConnection {
        fn new(remote_endpoint: &str) -> Self {
            Self {
                remote_endpoint: remote_endpoint.to_string(),
            }
        }
    }

    impl Connection for MockConnection {
        fn send(&mut self, _message: &[u8]) -> Result<(), SendError> {
            Ok(())
        }

        fn recv(&mut self) -> Result<Vec<u8>, RecvError> {
            unimplemented!()
        }

        fn remote_endpoint(&self) -> String {
            self.remote_endpoint.clone()
        }

        fn local_endpoint(&self) -> String {
            unimplemented!()
        }

        fn disconnect(&mut self) -> Result<(), DisconnectError> {
            Ok(())
        }

        fn evented(&self) -> &dyn mio::Evented {
            unimplemented!()
        }
    }

    /// Authorizer that unconditionally authorizes every connection as `authorized_id`.
    struct NoopAuthorizer {
        authorized_id: String,
    }

    impl NoopAuthorizer {
        fn new(id: &str) -> Self {
            Self {
                authorized_id: id.to_string(),
            }
        }
    }

    impl Authorizer for NoopAuthorizer {
        fn authorize_connection(
            &self,
            connection_id: String,
            connection: Box<dyn Connection>,
            callback: AuthorizerCallback,
            _expected_authorization: Option<ConnectionAuthorizationType>,
            _local_authorization: Option<ConnectionAuthorizationType>,
        ) -> Result<(), AuthorizerError> {
            (*callback)(AuthorizationResult::Authorized {
                connection_id,
                connection,
                identity: ConnectionAuthorizationType::Trust {
                    identity: self.authorized_id.clone(),
                },
                expected_authorization: ConnectionAuthorizationType::Trust {
                    identity: self.authorized_id.clone(),
                },
                local_authorization: ConnectionAuthorizationType::Trust {
                    identity: "node_id".to_string(),
                },
            })
            .map_err(|err| AuthorizerError(format!("Unable to return result: {}", err)))
        }
    }
}