use serde::{Deserialize, Serialize};
use zlink::{
Server,
introspect::{self, Type},
unix::{bind, connect},
};
/// End-to-end test: a single service exposing methods on two interfaces
/// (`org.example.auth` and `org.example.storage`) is served on a Unix socket
/// and exercised by `run_client` through the generated proxy traits.
///
/// # Errors
///
/// Returns an error if the stale socket file cannot be removed, the socket
/// cannot be bound, or either the server or the client future fails.
#[test_log::test(tokio::test(flavor = "multi_thread"))]
async fn multiple_interfaces() -> Result<(), Box<dyn std::error::Error>> {
    let socket_path = "/tmp/zlink-service-macro-multi-iface-test.sock";

    // A previous run may have left the socket file behind. A missing file is
    // fine; any other removal failure is reported.
    if let Err(e) = tokio::fs::remove_file(socket_path).await {
        if e.kind() != std::io::ErrorKind::NotFound {
            return Err(e.into());
        }
    }

    // Binding is I/O and can legitimately fail (permissions, path in use), so
    // the failure is propagated through the test's `Result` rather than
    // panicking.
    let listener = bind(socket_path)?;

    // Initial state: nobody authenticated, storage already holding two items.
    let service = MultiInterfaceService {
        user_authenticated: false,
        items: vec!["apple".to_string(), "banana".to_string()],
    };
    let server = Server::new(listener, service);

    // Race the server against the client: the test ends as soon as either
    // future completes, and the other one is dropped.
    tokio::select! {
        res = server.run() => res?,
        res = run_client(socket_path) => res?,
    }
    Ok(())
}
/// Client half of the test: connects to `socket_path` and checks the reply of
/// every auth and storage method, covering both success and error cases.
///
/// The outer `?` propagates call-level failures; the inner `unwrap` /
/// `unwrap_err` assert which side of the method's own `Result` came back.
async fn run_client(socket_path: &str) -> Result<(), Box<dyn std::error::Error>> {
    let mut client = connect(socket_path).await?;

    // --- org.example.auth ---

    // User info is refused while nobody has authenticated.
    let unauthenticated = client.get_user_info().await?.unwrap_err();
    assert_eq!(unauthenticated, AuthError::NotAuthenticated);

    // The correct password is accepted.
    client.authenticate(String::from("secret")).await?.unwrap();

    // A wrong password is rejected with an explanatory reason.
    let expected_rejection = AuthError::InvalidCredentials {
        reason: String::from("wrong password"),
    };
    let rejection = client
        .authenticate(String::from("wrong"))
        .await?
        .unwrap_err();
    assert_eq!(rejection, expected_rejection);

    // The failed attempt did not undo the earlier successful authentication.
    let user = client.get_user_info().await?.unwrap();
    assert_eq!(user.name, "TestUser");

    // --- org.example.storage ---

    let reported = client.item_count().await?.unwrap();
    assert_eq!(reported.count, 2);

    // An index past the end of the list yields `NotFound`.
    let missing = client.get_item(10).await?.unwrap_err();
    assert_eq!(missing, StorageError::NotFound);

    let first = client.get_item(0).await?.unwrap();
    assert_eq!(first.value, "apple");

    // Adding another item is refused with the quota limit in the error.
    let expected_quota = StorageError::QuotaExceeded { limit: 2 };
    let over_quota = client
        .add_item(String::from("cherry"))
        .await?
        .unwrap_err();
    assert_eq!(over_quota, expected_quota);

    Ok(())
}
// Reply payload of `item_count`: how many items the service holds.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Type)]
struct ItemCount {
    // Length of the service's item list at the time of the call.
    count: usize,
}
// Reply payload of `get_user_info`, returned only once a user is authenticated.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Type)]
struct UserInfo {
    // Name of the authenticated user.
    name: String,
}
// Reply payload of `get_item`: a copy of one stored item.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Type)]
struct Item {
    // The stored string at the requested index.
    value: String,
}
// Errors reported by the methods of the `org.example.auth` interface.
#[derive(Debug, Clone, PartialEq, zlink::ReplyError, introspect::ReplyError)]
#[zlink(interface = "org.example.auth")]
enum AuthError {
    // `get_user_info` was called before a successful `authenticate`.
    NotAuthenticated,
    // `authenticate` was called with a password the service does not accept.
    InvalidCredentials {
        // Human-readable explanation of the rejection.
        reason: String,
    },
}
// Errors reported by the methods of the `org.example.storage` interface.
#[derive(Debug, Clone, PartialEq, zlink::ReplyError, introspect::ReplyError)]
#[zlink(interface = "org.example.storage")]
enum StorageError {
    // `get_item` was called with an index that has no stored item.
    NotFound,
    // `add_item` was refused because the item list is already full.
    QuotaExceeded {
        // Maximum number of items the service accepts.
        limit: usize,
    },
}
/// State of the test service whose methods are spread over the
/// `org.example.auth` and `org.example.storage` interfaces.
struct MultiInterfaceService {
    /// Set to `true` by a successful `authenticate`; gates `get_user_info`.
    user_authenticated: bool,
    /// Stored items served by `item_count`, `get_item` and `add_item`.
    items: Vec<String>,
}
// Method implementations of the test service, grouped by interface.
//
// NOTE(review): only `authenticate` and `item_count` carry an explicit
// `#[zlink(interface = "...")]` attribute. The other methods presumably
// inherit the most recent interface declared above them — confirm against the
// `zlink::service` documentation.
#[zlink::service]
impl MultiInterfaceService {
    // Marks the user as authenticated when `password` is "secret"; any other
    // password is rejected and leaves the current state untouched.
    #[zlink(interface = "org.example.auth")]
    async fn authenticate(&mut self, password: String) -> Result<(), AuthError> {
        match password.as_str() {
            "secret" => {
                self.user_authenticated = true;
                Ok(())
            }
            _ => Err(AuthError::InvalidCredentials {
                reason: String::from("wrong password"),
            }),
        }
    }

    // Returns the fixed test user, or `NotAuthenticated` when no successful
    // `authenticate` call has happened yet.
    async fn get_user_info(&self) -> Result<UserInfo, AuthError> {
        self.user_authenticated
            .then(|| UserInfo {
                name: String::from("TestUser"),
            })
            .ok_or(AuthError::NotAuthenticated)
    }

    // Reports how many items are stored. This method cannot fail.
    #[zlink(interface = "org.example.storage")]
    async fn item_count(&self) -> ItemCount {
        let count = self.items.len();
        ItemCount { count }
    }

    // Returns a copy of the item at `index`, or `NotFound` when the index is
    // past the end of the list.
    async fn get_item(&self, index: usize) -> Result<Item, StorageError> {
        match self.items.get(index) {
            Some(stored) => Ok(Item {
                value: stored.clone(),
            }),
            None => Err(StorageError::NotFound),
        }
    }

    // Appends `item` while fewer than two items are stored; otherwise refuses
    // with `QuotaExceeded` and leaves the list unchanged.
    async fn add_item(&mut self, item: String) -> Result<(), StorageError> {
        if self.items.len() < 2 {
            self.items.push(item);
            Ok(())
        } else {
            Err(StorageError::QuotaExceeded { limit: 2 })
        }
    }
}
// Client-side view of the `org.example.auth` interface.
//
// Each method returns a nested result: the inner `Result` carries the
// method's reply or its `AuthError`; the outer `zlink::Result` presumably
// reports connection/protocol-level failures — confirm against the
// `zlink::proxy` documentation.
#[zlink::proxy("org.example.auth")]
trait AuthProxy {
    // Attempts to authenticate with `password`.
    async fn authenticate(&mut self, password: String) -> zlink::Result<Result<(), AuthError>>;

    // Fetches information about the authenticated user.
    async fn get_user_info(&mut self) -> zlink::Result<Result<UserInfo, AuthError>>;
}
// Client-side view of the `org.example.storage` interface.
//
// Each method returns a nested result: the inner `Result` carries the
// method's reply or its `StorageError`; the outer `zlink::Result` presumably
// reports connection/protocol-level failures — confirm against the
// `zlink::proxy` documentation.
#[zlink::proxy("org.example.storage")]
trait StorageProxy {
    // Asks how many items are stored.
    async fn item_count(&mut self) -> zlink::Result<Result<ItemCount, StorageError>>;

    // Fetches the item stored at `index`.
    async fn get_item(&mut self, index: usize) -> zlink::Result<Result<Item, StorageError>>;

    // Asks the service to store `item`.
    async fn add_item(&mut self, item: String) -> zlink::Result<Result<(), StorageError>>;
}