entrust_agent/
server.rs

1mod event;
2mod request;
3
4use crate::server::HandleResult::{Break, Continue};
5use crate::server::event::EventSender;
6use crate::{SOCKET_NAME, read_deserialized, send_serialized};
7pub use event::ServerEvent;
8use interprocess::local_socket::traits::ListenerExt;
9use interprocess::local_socket::{
10    GenericNamespaced, ListenerNonblockingMode, ListenerOptions, ToNsName,
11};
12pub use request::{GetAgeIdentityResponse, Request, SetAgeIdentityResponse};
13use std::io::{BufReader, Read, Write};
14use std::sync::mpsc;
15use std::sync::mpsc::channel;
16use std::time::Duration;
17use std::{io, thread};
18
/// In-memory state held by the agent while the server runs.
///
/// `age_identity` starts out empty; `handle_request` treats an empty string as
/// "no identity stored".
#[derive(Default)]
struct State {
    /// Identity string last stored by a `SetAgeIdentity` request.
    age_identity: String,
    /// PIN a `GetAgeIdentity` request must match exactly; `None` means the
    /// request must carry no PIN either.
    age_pin: Option<String>,
}

/// Hand-written instead of derived so that formatting the state with `{:?}`
/// (logs, panic messages, `dbg!`) never prints the identity — for age this is
/// the secret key — or the PIN. Only whether each value is set is shown.
impl std::fmt::Debug for State {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let identity = if self.age_identity.is_empty() {
            "<unset>"
        } else {
            "<redacted>"
        };
        f.debug_struct("State")
            .field("age_identity", &identity)
            .field("age_pin", &self.age_pin.as_ref().map(|_| "<redacted>"))
            .finish()
    }
}
24
25pub fn run_with_idle_timeout(timeout: Duration) -> io::Result<()> {
26    let (sender, receiver) = channel();
27
28    thread::spawn(move || run(Some(sender)));
29    receiver
30        .recv_timeout(Duration::from_secs(1))
31        .map_err(|_| io::Error::other("Server did not start"))?;
32
33    loop {
34        match receiver.recv_timeout(timeout) {
35            Ok(ServerEvent::Stopped) | Err(_) => break,
36            Ok(ServerEvent::RequestHandled) => continue,
37            Ok(_) => continue,
38        };
39    }
40    Ok(())
41}
42
43pub fn run(event_sender: Option<mpsc::Sender<ServerEvent>>) -> io::Result<()> {
44    let mut state = State::default();
45
46    let socket_name = SOCKET_NAME.to_ns_name::<GenericNamespaced>()?;
47
48    let options = ListenerOptions::new()
49        .name(socket_name)
50        .nonblocking(ListenerNonblockingMode::Neither);
51    let listener = options.create_sync()?;
52
53    event_sender.send_server_event(ServerEvent::Started)?;
54    for result in listener.incoming() {
55        let mut con = BufReader::new(result?);
56
57        let request: Request = read_deserialized(&mut con)?;
58
59        let handle_result = handle_request(request, &mut state, &mut con)?;
60        event_sender.send_server_event(ServerEvent::RequestHandled)?;
61        if let Break = handle_result {
62            break;
63        }
64    }
65    event_sender.send_server_event(ServerEvent::Stopped)?;
66    Ok(())
67}
68
/// What the server loop should do once a request has been handled.
#[derive(PartialEq, Debug)]
enum HandleResult {
    /// Stop accepting connections and shut the server down.
    Break,
    /// Keep serving the next connection.
    Continue,
}
74
75fn handle_request<R: Read + Write>(
76    request: Request,
77    state: &mut State,
78    con: &mut BufReader<R>,
79) -> io::Result<HandleResult> {
80    match request {
81        Request::SetAgeIdentity { identity, pin } => {
82            state.age_identity = identity;
83            state.age_pin = pin;
84            Ok(Continue)
85        }
86        Request::GetAgeIdentity { pin } => {
87            let response = if state.age_identity.is_empty() {
88                GetAgeIdentityResponse::NotSet
89            } else if pin == state.age_pin {
90                GetAgeIdentityResponse::Ok {
91                    identity: state.age_identity.clone(),
92                }
93            } else {
94                GetAgeIdentityResponse::WrongPin
95            };
96            send_serialized(&response, con.get_mut())?;
97            if let GetAgeIdentityResponse::WrongPin = response {
98                Ok(Break)
99            } else {
100                Ok(Continue)
101            }
102        }
103        Request::Shutdown => Ok(Break),
104    }
105}
106
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Cursor;

    /// Runs `request` against `state` over an in-memory connection and returns
    /// the handler's verdict plus everything it wrote, rewound for reading.
    fn handle(request: Request, state: &mut State) -> (HandleResult, Cursor<Vec<u8>>) {
        let mut con = BufReader::new(Cursor::new(Vec::new()));
        let result = handle_request(request, state, &mut con).unwrap();
        let mut written = con.into_inner();
        written.set_position(0);
        (result, written)
    }

    /// Sends a `GetAgeIdentity` carrying `pin` and decodes the response.
    fn get_identity(
        state: &mut State,
        pin: Option<&str>,
    ) -> (HandleResult, GetAgeIdentityResponse) {
        let request = Request::GetAgeIdentity {
            pin: pin.map(String::from),
        };
        let (result, mut written) = handle(request, state);
        let response: GetAgeIdentityResponse = read_deserialized(&mut written).unwrap();
        (result, response)
    }

    /// Builds a state holding `identity`, optionally protected by `pin`.
    fn state_with(identity: &str, pin: Option<&str>) -> State {
        State {
            age_identity: identity.to_string(),
            age_pin: pin.map(String::from),
        }
    }

    #[test]
    fn test_serialization() {
        let req = Request::SetAgeIdentity {
            identity: "id".to_string(),
            pin: Some("pin".to_string()),
        };
        let mut buf = Cursor::new(Vec::new());
        send_serialized(&req, &mut buf).unwrap();
        buf.set_position(0);
        let deserialized: Request = read_deserialized(&mut buf).unwrap();
        assert_eq!(req, deserialized);
    }

    #[test]
    fn test_handle_get_age_empty() {
        let (result, response) = get_identity(&mut State::default(), None);
        assert_eq!(Continue, result);
        assert_eq!(GetAgeIdentityResponse::NotSet, response);
    }

    #[test]
    fn test_handle_get_age_empty_ignores_pin() {
        let (result, response) = get_identity(&mut State::default(), Some("pin"));
        assert_eq!(Continue, result);
        assert_eq!(GetAgeIdentityResponse::NotSet, response);
    }

    #[test]
    fn test_handle_get_age() {
        let (result, response) = get_identity(&mut state_with("id", None), None);
        assert_eq!(Continue, result);
        assert_eq!(
            GetAgeIdentityResponse::Ok {
                identity: "id".to_string()
            },
            response
        );
    }

    #[test]
    fn test_handle_get_age_password() {
        let (result, response) =
            get_identity(&mut state_with("id", Some("pass")), Some("pass"));
        assert_eq!(Continue, result);
        assert_eq!(
            GetAgeIdentityResponse::Ok {
                identity: "id".to_string()
            },
            response
        );
    }

    #[test]
    fn test_handle_get_age_wrong_password() {
        let (result, response) =
            get_identity(&mut state_with("id", Some("pass")), Some("wrong pass"));
        assert_eq!(Break, result);
        assert_eq!(GetAgeIdentityResponse::WrongPin, response);
    }

    #[test]
    fn test_handle_get_age_missing_password() {
        let (result, response) = get_identity(&mut state_with("id", Some("pass")), None);
        assert_eq!(Break, result);
        assert_eq!(GetAgeIdentityResponse::WrongPin, response);
    }

    #[test]
    fn test_handle_get_age_unexpected_password() {
        let (result, response) = get_identity(&mut state_with("id", None), Some("pass"));
        assert_eq!(Break, result);
        assert_eq!(GetAgeIdentityResponse::WrongPin, response);
    }

    #[test]
    fn test_handle_set_age_identity() {
        let mut state = State::default();
        let request = Request::SetAgeIdentity {
            identity: "id".to_string(),
            pin: Some("pin".to_string()),
        };
        let (result, written) = handle(request, &mut state);
        assert_eq!(Continue, result);
        assert!(written.get_ref().is_empty());
        assert_eq!(state.age_identity, "id");
        assert_eq!(state.age_pin, Some("pin".to_string()));

        let (result, response) = get_identity(&mut state, Some("pin"));
        assert_eq!(Continue, result);
        assert_eq!(
            GetAgeIdentityResponse::Ok {
                identity: "id".to_string()
            },
            response
        );
    }

    #[test]
    fn test_handle_set_age_identity_replaces_previous() {
        let mut state = state_with("old", Some("old pin"));
        let request = Request::SetAgeIdentity {
            identity: "new".to_string(),
            pin: None,
        };
        let (result, _) = handle(request, &mut state);
        assert_eq!(Continue, result);
        assert_eq!(state.age_identity, "new");
        assert_eq!(state.age_pin, None);
    }

    #[test]
    fn test_handle_shutdown() {
        let mut state = state_with("id", Some("pin"));
        let (result, written) = handle(Request::Shutdown, &mut state);
        assert_eq!(Break, result);
        assert!(written.get_ref().is_empty());
        assert_eq!(state.age_identity, "id");
        assert_eq!(state.age_pin, Some("pin".to_string()));
    }
}