#![allow(non_camel_case_types)]
use std::{
collections::HashMap,
io::{Read, Write},
os::unix::net::{UnixListener, UnixStream},
path::PathBuf,
sync::{
atomic::{AtomicU32, Ordering},
Arc,
},
thread,
};
use crossbeam_channel::{unbounded, Receiver, Sender};
use rust_utee::tee_api_defines::TEE_SUCCESS;
use crate::error::{Error, Result};
use teec_protocol::{CaAuthInfo, Parameters, TARequest, TEE_Request, TEE_Response};
const SERVER_SOCKET_PATH: &str = "/tmp/server.sock";
/// A trusted application hosted by [`TAManager`].
///
/// `TAManager::run_ta` calls [`create`](Self::create) before serving. For each
/// `OpenSession` request the manager calls [`acl_check`](Self::acl_check) and, if that
/// succeeds, [`open_session`](Self::open_session). The context returned by a successful
/// `open_session` is then handed to [`invoke_command`](Self::invoke_command) and
/// [`close_session`](Self::close_session) for that session.
pub trait TrustedApplication: Send + Sync + 'static {
    /// Per-session state produced by `open_session`.
    ///
    /// The manager moves it into a per-session worker thread, hence the `Send` bound.
    type SessionContext: Send;

    /// One-time initialisation, invoked before the TA registers and starts serving.
    fn create(&self) -> Result<()>;

    /// Tear-down hook for the TA.
    ///
    /// NOTE(review): `TAManager` does not call this in the code visible here — confirm
    /// where (or whether) it is expected to be invoked.
    fn destroy(&self) -> Result<()>;

    /// Decides whether the calling CA may open a session.
    ///
    /// `ca_auth_info` is the authentication info carried by the `OpenSession` request,
    /// if any. Returning `Err` rejects the open, and the error is converted into the
    /// status code sent back to the CA. The default implementation accepts every caller.
    fn acl_check(&self, ca_auth_info: Option<&CaAuthInfo>) -> Result<()> {
        let _ = ca_auth_info;
        Ok(())
    }

    /// Opens a new session and returns its context.
    ///
    /// `params` may be modified in place.
    fn open_session(&self, params: &mut Parameters) -> Result<Self::SessionContext>;

    /// Handles command `cmd_id` within the session identified by `ctx`.
    ///
    /// `params` is returned to the CA after this call, including any in-place changes.
    fn invoke_command(
        &self,
        cmd_id: u32,
        params: &mut Parameters,
        ctx: &mut Self::SessionContext,
    ) -> Result<()>;

    /// Closes the session identified by `ctx`.
    ///
    /// This is the last call made with this context.
    fn close_session(&self, ctx: &mut Self::SessionContext) -> Result<()>;
}
/// Hosts one [`TrustedApplication`] and routes CA requests to its sessions.
///
/// Every open session is served by its own worker thread. `sessions` maps a session ID
/// to the channel used to reach that thread.
pub struct TAManager<T>
where
    T: TrustedApplication,
{
    /// Counter from which new session IDs are drawn. The constructor sets the starting value.
    session_id: AtomicU32,
    /// Open sessions, keyed by session ID.
    sessions: HashMap<u32, Sender<SessionMessage>>,
    /// The hosted application, shared with every session worker thread.
    ta: Arc<T>,
    /// UUID the TA registers under. It also names the CA socket, `/tmp/<uuid>.sock`.
    uuid: String,
}
impl<T: TrustedApplication> TAManager<T> {
pub fn new(ta: T, uuid: &str) -> Self {
Self {
ta: Arc::new(ta),
uuid: uuid.to_string(),
sessions: HashMap::new(),
session_id: AtomicU32::new(1),
}
}
pub fn run_ta(&mut self) -> anyhow::Result<()> {
self.ta.create()?;
let _stream = self.register_ta()?;
self.handle_ca_request(self.ta.clone())?;
Ok(())
}
fn register_ta(&self) -> anyhow::Result<UnixStream> {
let mut stream = UnixStream::connect(SERVER_SOCKET_PATH)?;
let req = TARequest::Register {
uuid: self.uuid.clone(),
};
let data = postcard::to_allocvec(&req)?;
stream.write_all(&data)?;
println!("TA registered with UUID: {}", self.uuid);
Ok(stream)
}
fn handle_ca_request(&mut self, ta: Arc<T>) -> anyhow::Result<()> {
let path = PathBuf::from(format!("/tmp/{}.sock", self.uuid));
let _ = std::fs::remove_file(path.clone());
let listener = UnixListener::bind(path.clone())?;
println!("TA listening on socket: {:?}", path);
for stream in listener.incoming() {
println!("Received connection from CA");
let mut stream = stream?;
let mut len_buf = [0u8; 4];
stream.read_exact(&mut len_buf)?;
let len = u32::from_ne_bytes(len_buf) as usize;
let mut buf = vec![0u8; len];
stream.read_exact(&mut buf)?;
let req: TEE_Request = postcard::from_bytes(&buf)?;
match &req {
TEE_Request::OpenSession { .. } => {
self.handle_open_session(stream, ta.clone(), req)?
}
TEE_Request::CloseSession { .. } => self.handle_close_session(stream, req)?,
TEE_Request::InvokeCommand { .. } => self.handle_invoke_command(stream, req)?,
TEE_Request::RequestCancellation { .. } => todo!(),
}
}
Ok(())
}
fn handle_open_session(
&mut self,
mut stream: UnixStream,
ta: Arc<T>,
req: TEE_Request,
) -> anyhow::Result<()> {
let (mut params, ca_auth_info) = match req {
TEE_Request::OpenSession {
params,
ca_auth_info,
..
} => (params, ca_auth_info),
_ => return Err(anyhow::anyhow!("Invalid request type for open_session")),
};
let resp = if let Err(e) = ta.acl_check(ca_auth_info.as_ref()) {
println!("ACL check failed: {:?}", e);
TEE_Response::OpenSession {
session_id: 0,
result: e.into(),
}
} else {
let session_id = self.next_session_id();
println!("Opening session with ID: {}", session_id);
match ta.open_session(&mut params) {
Ok(ctx) => {
println!("Session {} opened successfully", session_id);
let (tx, rx) = unbounded();
self.sessions.insert(session_id, tx);
thread::spawn(move || {
session_thread(ta.clone(), ctx, rx);
});
TEE_Response::OpenSession {
session_id,
result: TEE_SUCCESS,
}
}
Err(e) => {
println!("Failed to open session {}: {:?}", session_id, e);
TEE_Response::OpenSession {
session_id,
result: e.into(),
}
}
}
};
let resp_data = postcard::to_allocvec(&resp)?;
let mut message = Vec::with_capacity(4 + resp_data.len());
message.extend_from_slice(&(resp_data.len() as u32).to_ne_bytes());
message.extend_from_slice(&resp_data);
stream.write_all(&message)?;
Ok(())
}
fn handle_close_session(
&mut self,
mut stream: UnixStream,
req: TEE_Request,
) -> anyhow::Result<()> {
let session_id = match req {
TEE_Request::CloseSession { session_id } => session_id,
_ => return Err(anyhow::anyhow!("Invalid request type for close_session")),
};
println!("Closing session with ID: {}", session_id);
let resp = match self.sessions.get(&session_id) {
Some(tx) => {
let (resp_tx, resp_rx) = unbounded();
tx.send(SessionMessage::Close { resp_tx })?;
resp_rx.recv()?
}
None => {
println!("Session {} not found", session_id);
TEE_Response::CloseSession {
result: Error::ItemNotFound.into(),
}
}
};
let resp_data = postcard::to_allocvec(&resp)?;
let mut message = Vec::with_capacity(4 + resp_data.len());
message.extend_from_slice(&(resp_data.len() as u32).to_ne_bytes());
message.extend_from_slice(&resp_data);
stream.write_all(&message)?;
Ok(())
}
fn handle_invoke_command(
&mut self,
mut stream: UnixStream,
req: TEE_Request,
) -> anyhow::Result<()> {
let (session_id, cmd_id, params) = match req {
TEE_Request::InvokeCommand {
session_id,
cmd_id,
params,
} => (session_id, cmd_id, params),
_ => return Err(anyhow::anyhow!("Invalid request type for invoke_command")),
};
println!("Invoking command {} on session {}", cmd_id, session_id);
let resp = match self.sessions.get(&session_id) {
Some(tx) => {
let (resp_tx, resp_rx) = unbounded();
tx.send(SessionMessage::Invoke {
cmd_id,
params,
resp_tx,
})?;
resp_rx.recv()?
}
None => {
println!("Session {} not found", session_id);
TEE_Response::InvokeCommand {
params,
result: Error::ItemNotFound.into(),
}
}
};
let resp_data = postcard::to_allocvec(&resp)?;
let mut message = Vec::with_capacity(4 + resp_data.len());
message.extend_from_slice(&(resp_data.len() as u32).to_ne_bytes());
message.extend_from_slice(&resp_data);
stream.write_all(&message)?;
Ok(())
}
fn next_session_id(&self) -> u32 {
self.session_id.fetch_add(1, Ordering::SeqCst)
}
}
/// A request sent from a connection handler to a session's worker thread.
///
/// Each variant carries `resp_tx`, a reply channel created for that single request. The
/// worker sends the response to forward to the CA on this channel.
enum SessionMessage {
    /// Run `close_session` on the session context, reply, then stop the worker.
    Close { resp_tx: Sender<TEE_Response> },
    /// Run `invoke_command(cmd_id, params, ..)` on the session context and reply.
    Invoke {
        cmd_id: u32,
        params: Parameters,
        resp_tx: Sender<TEE_Response>,
    },
}
/// Worker loop for one session.
///
/// It owns `ctx` and processes messages from `rx` in arrival order:
/// - `Invoke` calls [`TrustedApplication::invoke_command`] and replies with the
///   parameters, including any in-place changes, plus a status code.
/// - `Close` calls [`TrustedApplication::close_session`], replies, and returns.
///
/// The loop also ends once every sender for `rx` has been dropped. In that case
/// `close_session` is not called. Failures to deliver a reply are ignored.
fn session_thread<T: TrustedApplication>(
    ta: Arc<T>,
    mut ctx: T::SessionContext,
    rx: Receiver<SessionMessage>,
) {
    while let Ok(message) = rx.recv() {
        match message {
            SessionMessage::Close { resp_tx } => {
                let result = match ta.close_session(&mut ctx) {
                    Ok(()) => TEE_SUCCESS,
                    Err(err) => err.into(),
                };
                let _ = resp_tx.send(TEE_Response::CloseSession { result });
                return;
            }
            SessionMessage::Invoke {
                cmd_id,
                mut params,
                resp_tx,
            } => {
                let result = match ta.invoke_command(cmd_id, &mut params, &mut ctx) {
                    Ok(()) => TEE_SUCCESS,
                    Err(err) => err.into(),
                };
                let _ = resp_tx.send(TEE_Response::InvokeCommand { params, result });
            }
        }
    }
}