// xtee-utee 0.2.0
//
// TEE internal API bindings for xTEE Trusted Applications.
// Documentation
// SPDX-License-Identifier: Apache-2.0
// Copyright (C) 2026 KylinSoft Co., Ltd. <https://www.kylinos.cn/>
// See LICENSES for license details.

#![allow(non_camel_case_types)]

use std::{
    collections::HashMap,
    io::{Read, Write},
    os::unix::net::{UnixListener, UnixStream},
    path::PathBuf,
    sync::{
        atomic::{AtomicU32, Ordering},
        Arc,
    },
    thread,
};

use crossbeam_channel::{unbounded, Receiver, Sender};
use rust_utee::tee_api_defines::TEE_SUCCESS;

use crate::error::{Error, Result};
use teec_protocol::{CaAuthInfo, Parameters, TARequest, TEE_Request, TEE_Response};

/// Unix-domain socket path that `TAManager::register_ta` connects to in order
/// to announce this TA's UUID.
const SERVER_SOCKET_PATH: &str = "/tmp/server.sock";

/// Trait representing a Trusted Application (TA).
///
/// An implementation is wrapped in an `Arc` by [`TAManager`] and shared with
/// one worker thread per open session, which is why the trait requires
/// `Send + Sync + 'static`.
pub trait TrustedApplication: Send + Sync + 'static {
    /// User-defined session context type.
    ///
    /// A value of this type is returned by [`open_session`](Self::open_session),
    /// moved into the session's worker thread, and handed back by mutable
    /// reference to `invoke_command` and `close_session`.
    type SessionContext: Send;

    /// Create a new TA instance.
    ///
    /// `TAManager::run_ta` calls this once, before the TA is registered; an
    /// error aborts `run_ta`.
    fn create(&self) -> Result<()>;

    /// ACL check for the TA using CA authentication info from `OpenSession`.
    ///
    /// Runs before `open_session`; returning an error rejects the request and
    /// `open_session` is not called. The default implementation accepts every
    /// caller.
    fn acl_check(&self, _ca_auth_info: Option<&CaAuthInfo>) -> Result<()> {
        Ok(())
    }

    /// Open a session with the TA.
    ///
    /// `params` are the parameters carried by the `OpenSession` request. On
    /// success the returned context becomes the state of the new session.
    fn open_session(&self, params: &mut Parameters) -> Result<Self::SessionContext>;

    /// Close the session with the TA.
    ///
    /// Called on the session's worker thread. The worker exits afterwards
    /// whether this returns `Ok` or `Err`.
    fn close_session(&self, ctx: &mut Self::SessionContext) -> Result<()>;

    /// Destroy the TA instance.
    ///
    /// NOTE(review): no call site for this method is visible in this file —
    /// confirm where (or whether) the runtime invokes it.
    fn destroy(&self) -> Result<()>;

    /// Invoke a command on the TA.
    ///
    /// Called on the session's worker thread with the session's own context.
    /// `params` is sent back to the caller after the call returns, on both the
    /// success and the error path.
    fn invoke_command(
        &self,
        cmd_id: u32,
        params: &mut Parameters,
        ctx: &mut Self::SessionContext,
    ) -> Result<()>;
}

/// Manager for a Trusted Application (TA).
///
/// Owns the TA instance, registers it under its UUID, accepts Client
/// Application (CA) connections and routes each request to the worker thread
/// of the session it addresses.
pub struct TAManager<T: TrustedApplication> {
    /// The TA implementation, shared with every session worker thread.
    ta: Arc<T>,
    /// UUID sent on registration; also used to build the listening socket
    /// path `/tmp/<uuid>.sock`.
    uuid: String,
    /// Open sessions: session ID -> channel into that session's worker thread.
    sessions: HashMap<u32, Sender<SessionMessage>>,
    /// Counter from which session IDs are drawn (initialised to 1 by `new`).
    session_id: AtomicU32,
}

impl<T: TrustedApplication> TAManager<T> {
    pub fn new(ta: T, uuid: &str) -> Self {
        Self {
            ta: Arc::new(ta),
            uuid: uuid.to_string(),
            sessions: HashMap::new(),
            session_id: AtomicU32::new(1),
        }
    }

    pub fn run_ta(&mut self) -> anyhow::Result<()> {
        self.ta.create()?;
        let _stream = self.register_ta()?;
        self.handle_ca_request(self.ta.clone())?;
        Ok(())
    }

    // Register the TA with the TA Manager.
    fn register_ta(&self) -> anyhow::Result<UnixStream> {
        let mut stream = UnixStream::connect(SERVER_SOCKET_PATH)?;

        let req = TARequest::Register {
            uuid: self.uuid.clone(),
        };
        let data = postcard::to_allocvec(&req)?;
        stream.write_all(&data)?;
        println!("TA registered with UUID: {}", self.uuid);

        Ok(stream)
    }

    // Handle requests from the Client Application (CA).
    fn handle_ca_request(&mut self, ta: Arc<T>) -> anyhow::Result<()> {
        let path = PathBuf::from(format!("/tmp/{}.sock", self.uuid));
        let _ = std::fs::remove_file(path.clone());

        let listener = UnixListener::bind(path.clone())?;
        println!("TA listening on socket: {:?}", path);

        for stream in listener.incoming() {
            println!("Received connection from CA");
            let mut stream = stream?;

            let mut len_buf = [0u8; 4];
            stream.read_exact(&mut len_buf)?;
            let len = u32::from_ne_bytes(len_buf) as usize;
            let mut buf = vec![0u8; len];
            stream.read_exact(&mut buf)?;

            let req: TEE_Request = postcard::from_bytes(&buf)?;
            match &req {
                TEE_Request::OpenSession { .. } => {
                    self.handle_open_session(stream, ta.clone(), req)?
                }
                TEE_Request::CloseSession { .. } => self.handle_close_session(stream, req)?,
                TEE_Request::InvokeCommand { .. } => self.handle_invoke_command(stream, req)?,
                TEE_Request::RequestCancellation { .. } => todo!(),
            }
        }

        Ok(())
    }

    fn handle_open_session(
        &mut self,
        mut stream: UnixStream,
        ta: Arc<T>,
        req: TEE_Request,
    ) -> anyhow::Result<()> {
        // 从请求中提取参数
        let (mut params, ca_auth_info) = match req {
            TEE_Request::OpenSession {
                params,
                ca_auth_info,
                ..
            } => (params, ca_auth_info),
            _ => return Err(anyhow::anyhow!("Invalid request type for open_session")),
        };

        let resp = if let Err(e) = ta.acl_check(ca_auth_info.as_ref()) {
            println!("ACL check failed: {:?}", e);
            TEE_Response::OpenSession {
                session_id: 0,
                result: e.into(),
            }
        } else {
            let session_id = self.next_session_id();
            println!("Opening session with ID: {}", session_id);

            match ta.open_session(&mut params) {
                Ok(ctx) => {
                    println!("Session {} opened successfully", session_id);
                    let (tx, rx) = unbounded();
                    self.sessions.insert(session_id, tx);
                    thread::spawn(move || {
                        session_thread(ta.clone(), ctx, rx);
                    });
                    TEE_Response::OpenSession {
                        session_id,
                        result: TEE_SUCCESS,
                    }
                }
                Err(e) => {
                    println!("Failed to open session {}: {:?}", session_id, e);
                    TEE_Response::OpenSession {
                        session_id,
                        result: e.into(),
                    }
                }
            }
        };

        let resp_data = postcard::to_allocvec(&resp)?;
        let mut message = Vec::with_capacity(4 + resp_data.len());
        message.extend_from_slice(&(resp_data.len() as u32).to_ne_bytes());
        message.extend_from_slice(&resp_data);
        stream.write_all(&message)?;

        Ok(())
    }

    fn handle_close_session(
        &mut self,
        mut stream: UnixStream,
        req: TEE_Request,
    ) -> anyhow::Result<()> {
        // 从请求中提取 session_id
        let session_id = match req {
            TEE_Request::CloseSession { session_id } => session_id,
            _ => return Err(anyhow::anyhow!("Invalid request type for close_session")),
        };

        println!("Closing session with ID: {}", session_id);

        let resp = match self.sessions.get(&session_id) {
            Some(tx) => {
                let (resp_tx, resp_rx) = unbounded();
                tx.send(SessionMessage::Close { resp_tx })?;
                resp_rx.recv()?
            }
            None => {
                println!("Session {} not found", session_id);
                TEE_Response::CloseSession {
                    result: Error::ItemNotFound.into(),
                }
            }
        };

        let resp_data = postcard::to_allocvec(&resp)?;
        let mut message = Vec::with_capacity(4 + resp_data.len());
        message.extend_from_slice(&(resp_data.len() as u32).to_ne_bytes());
        message.extend_from_slice(&resp_data);
        stream.write_all(&message)?;

        Ok(())
    }

    fn handle_invoke_command(
        &mut self,
        mut stream: UnixStream,
        req: TEE_Request,
    ) -> anyhow::Result<()> {
        // 从请求中提取参数
        let (session_id, cmd_id, params) = match req {
            TEE_Request::InvokeCommand {
                session_id,
                cmd_id,
                params,
            } => (session_id, cmd_id, params),
            _ => return Err(anyhow::anyhow!("Invalid request type for invoke_command")),
        };

        println!("Invoking command {} on session {}", cmd_id, session_id);

        let resp = match self.sessions.get(&session_id) {
            Some(tx) => {
                let (resp_tx, resp_rx) = unbounded();
                tx.send(SessionMessage::Invoke {
                    cmd_id,
                    params,
                    resp_tx,
                })?;
                resp_rx.recv()?
            }
            None => {
                println!("Session {} not found", session_id);
                TEE_Response::InvokeCommand {
                    params,
                    result: Error::ItemNotFound.into(),
                }
            }
        };

        let resp_data = postcard::to_allocvec(&resp)?;
        let mut message = Vec::with_capacity(4 + resp_data.len());
        message.extend_from_slice(&(resp_data.len() as u32).to_ne_bytes());
        message.extend_from_slice(&resp_data);
        stream.write_all(&message)?;

        Ok(())
    }

    fn next_session_id(&self) -> u32 {
        self.session_id.fetch_add(1, Ordering::SeqCst)
    }
}

/// Messages sent to session threads.
///
/// Every variant carries a `resp_tx` on which the session thread sends back
/// the `TEE_Response` for that request.
enum SessionMessage {
    /// Run command `cmd_id` with `params` in the session.
    Invoke {
        cmd_id: u32,
        params: Parameters,
        resp_tx: Sender<TEE_Response>,
    },
    /// Close the session; the session thread stops after replying.
    Close {
        resp_tx: Sender<TEE_Response>,
    },
}

// Worker loop for a single TA session.
//
// Receives messages from `rx` until a `Close` message has been handled or the
// channel is disconnected. Each message is answered on its own `resp_tx`; a
// failed reply (the requester is gone) is ignored.
fn session_thread<T: TrustedApplication>(
    ta: Arc<T>,
    mut ctx: T::SessionContext,
    rx: Receiver<SessionMessage>,
) {
    while let Ok(message) = rx.recv() {
        match message {
            SessionMessage::Invoke {
                cmd_id,
                mut params,
                resp_tx,
            } => {
                // Work out the result code first; `params` is returned to the
                // caller on both the success and the error path.
                let result = match ta.invoke_command(cmd_id, &mut params, &mut ctx) {
                    Ok(()) => TEE_SUCCESS,
                    Err(err) => err.into(),
                };
                let _ = resp_tx.send(TEE_Response::InvokeCommand { params, result });
            }
            SessionMessage::Close { resp_tx } => {
                let result = match ta.close_session(&mut ctx) {
                    Ok(()) => TEE_SUCCESS,
                    Err(err) => err.into(),
                };
                let _ = resp_tx.send(TEE_Response::CloseSession { result });
                // The session ends here regardless of the close result.
                return;
            }
        }
    }
}