1mod event;
2mod request;
3
4use crate::server::HandleResult::{Break, Continue};
5use crate::server::event::EventSender;
6use crate::{SOCKET_NAME, read_deserialized, send_serialized};
7pub use event::ServerEvent;
8use interprocess::local_socket::traits::ListenerExt;
9use interprocess::local_socket::{
10 GenericNamespaced, ListenerNonblockingMode, ListenerOptions, ToNsName,
11};
12pub use request::{GetAgeIdentityResponse, Request, SetAgeIdentityResponse};
13use std::io::{BufReader, Read, Write};
14use std::sync::mpsc;
15use std::sync::mpsc::channel;
16use std::time::Duration;
17use std::{io, thread};
18
/// In-memory state kept by the server between requests.
///
/// `Debug` is implemented by hand instead of derived so that the stored identity and PIN
/// can never leak into logs or panic messages; it only reports whether each is present.
#[derive(Default)]
struct State {
    /// Identity stored by the most recent `SetAgeIdentity` request; empty when none is set.
    age_identity: String,
    /// PIN that a `GetAgeIdentity` request must supply; `None` means the request must
    /// also omit its PIN.
    age_pin: Option<String>,
}

impl std::fmt::Debug for State {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let identity = if self.age_identity.is_empty() {
            "<unset>"
        } else {
            "<redacted>"
        };
        f.debug_struct("State")
            .field("age_identity", &identity)
            .field("age_pin", &self.age_pin.as_ref().map(|_| "<redacted>"))
            .finish()
    }
}
24
25pub fn run_with_idle_timeout(timeout: Duration) -> io::Result<()> {
26 let (sender, receiver) = channel();
27
28 thread::spawn(move || run(Some(sender)));
29 receiver
30 .recv_timeout(Duration::from_secs(1))
31 .map_err(|_| io::Error::other("Server did not start"))?;
32
33 loop {
34 match receiver.recv_timeout(timeout) {
35 Ok(ServerEvent::Stopped) | Err(_) => break,
36 Ok(ServerEvent::RequestHandled) => continue,
37 Ok(_) => continue,
38 };
39 }
40 Ok(())
41}
42
43pub fn run(event_sender: Option<mpsc::Sender<ServerEvent>>) -> io::Result<()> {
44 let mut state = State::default();
45
46 let socket_name = SOCKET_NAME.to_ns_name::<GenericNamespaced>()?;
47
48 let options = ListenerOptions::new()
49 .name(socket_name)
50 .nonblocking(ListenerNonblockingMode::Neither);
51 let listener = options.create_sync()?;
52
53 event_sender.send_server_event(ServerEvent::Started)?;
54 for result in listener.incoming() {
55 let mut con = BufReader::new(result?);
56
57 let request: Request = read_deserialized(&mut con)?;
58
59 let handle_result = handle_request(request, &mut state, &mut con)?;
60 event_sender.send_server_event(ServerEvent::RequestHandled)?;
61 if let Break = handle_result {
62 break;
63 }
64 }
65 event_sender.send_server_event(ServerEvent::Stopped)?;
66 Ok(())
67}
68
/// Tells the accept loop in [`run`] what to do once a request has been handled.
#[derive(PartialEq, Debug)]
enum HandleResult {
    /// Stop serving: leave the accept loop and shut the server down.
    Break,
    /// Keep serving: wait for the next connection.
    Continue,
}
74
75fn handle_request<R: Read + Write>(
76 request: Request,
77 state: &mut State,
78 con: &mut BufReader<R>,
79) -> io::Result<HandleResult> {
80 match request {
81 Request::SetAgeIdentity { identity, pin } => {
82 state.age_identity = identity;
83 state.age_pin = pin;
84 Ok(Continue)
85 }
86 Request::GetAgeIdentity { pin } => {
87 let response = if state.age_identity.is_empty() {
88 GetAgeIdentityResponse::NotSet
89 } else if pin == state.age_pin {
90 GetAgeIdentityResponse::Ok {
91 identity: state.age_identity.clone(),
92 }
93 } else {
94 GetAgeIdentityResponse::WrongPin
95 };
96 send_serialized(&response, con.get_mut())?;
97 if let GetAgeIdentityResponse::WrongPin = response {
98 Ok(Break)
99 } else {
100 Ok(Continue)
101 }
102 }
103 Request::Shutdown => Ok(Break),
104 }
105}
106
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Cursor;

    /// Runs `request` through `handle_request` over an in-memory connection and returns
    /// the handler's result plus everything it wrote, rewound for reading.
    fn dispatch(request: Request, state: &mut State) -> (HandleResult, Cursor<Vec<u8>>) {
        let mut con = BufReader::new(Cursor::new(Vec::new()));
        let result = handle_request(request, state, &mut con).unwrap();
        let mut written = con.into_inner();
        written.set_position(0);
        (result, written)
    }

    /// Sends a `GetAgeIdentity` request carrying `pin` and decodes the reply.
    fn get_identity(
        state: &mut State,
        pin: Option<&str>,
    ) -> (HandleResult, GetAgeIdentityResponse) {
        let (result, mut written) = dispatch(
            Request::GetAgeIdentity {
                pin: pin.map(str::to_string),
            },
            state,
        );
        let response = read_deserialized(&mut written).unwrap();
        (result, response)
    }

    /// Builds a state holding identity `"id"` protected by `pin`.
    fn state_with(pin: Option<&str>) -> State {
        State {
            age_identity: "id".to_string(),
            age_pin: pin.map(str::to_string),
        }
    }

    #[test]
    fn test_serialization() {
        let req = Request::SetAgeIdentity {
            identity: "id".to_string(),
            pin: Some("pin".to_string()),
        };
        let mut buf = Cursor::new(Vec::new());
        send_serialized(&req, &mut buf).unwrap();
        buf.set_position(0);
        let deserialized = read_deserialized(&mut buf).unwrap();
        assert_eq!(req, deserialized);
    }

    #[test]
    fn test_handle_set_age() {
        let mut state = State::default();
        let (result, written) = dispatch(
            Request::SetAgeIdentity {
                identity: "id".to_string(),
                pin: Some("pin".to_string()),
            },
            &mut state,
        );
        assert_eq!(Continue, result);
        assert!(written.get_ref().is_empty(), "SetAgeIdentity must not reply");
        assert_eq!("id", state.age_identity);
        assert_eq!(Some("pin".to_string()), state.age_pin);
    }

    #[test]
    fn test_handle_shutdown() {
        let mut state = state_with(None);
        let (result, written) = dispatch(Request::Shutdown, &mut state);
        assert_eq!(Break, result);
        assert!(written.get_ref().is_empty(), "Shutdown must not reply");
    }

    #[test]
    fn test_handle_get_age_empty() {
        let mut state = State::default();
        assert_eq!(
            (Continue, GetAgeIdentityResponse::NotSet),
            get_identity(&mut state, None)
        );
    }

    #[test]
    fn test_handle_get_age() {
        let mut state = state_with(None);
        assert_eq!(
            (
                Continue,
                GetAgeIdentityResponse::Ok {
                    identity: "id".to_string()
                }
            ),
            get_identity(&mut state, None)
        );
    }

    #[test]
    fn test_handle_get_age_password() {
        let mut state = state_with(Some("pass"));
        assert_eq!(
            (
                Continue,
                GetAgeIdentityResponse::Ok {
                    identity: "id".to_string()
                }
            ),
            get_identity(&mut state, Some("pass"))
        );
    }

    #[test]
    fn test_handle_get_age_wrong_password() {
        let mut state = state_with(Some("pass"));
        assert_eq!(
            (Break, GetAgeIdentityResponse::WrongPin),
            get_identity(&mut state, Some("wrong pass"))
        );
    }

    #[test]
    fn test_handle_get_age_missing_password() {
        let mut state = state_with(Some("pass"));
        assert_eq!(
            (Break, GetAgeIdentityResponse::WrongPin),
            get_identity(&mut state, None)
        );
    }

    #[test]
    fn test_set_then_get_age() {
        let mut state = State::default();
        let (result, _) = dispatch(
            Request::SetAgeIdentity {
                identity: "id".to_string(),
                pin: Some("pass".to_string()),
            },
            &mut state,
        );
        assert_eq!(Continue, result);
        assert_eq!(
            (
                Continue,
                GetAgeIdentityResponse::Ok {
                    identity: "id".to_string()
                }
            ),
            get_identity(&mut state, Some("pass"))
        );
    }
}