playit_agent_core/agent_control/
connected_control.rs1use std::{net::SocketAddr, time::Duration};
2
3use message_encoding::MessageEncoding;
4use playit_agent_proto::{control_feed::ControlFeed, control_messages::{AgentRegistered, ControlRequest, ControlResponse, Ping, Pong}, raw_slice::RawSlice, rpc::ControlRpcMessage};
5
6use crate::utils::now_milli;
7
8use super::{errors::{ControlError, SetupError}, established_control::EstablishedControl, AuthResource, PacketIO};
9
10#[derive(Debug)]
11pub struct ConnectedControl<IO: PacketIO> {
12 pub(super) control_addr: SocketAddr,
13 pub(super) packet_io: IO,
14 pub(super) pong_latest: Pong,
15 pub(super) buffer: Vec<u8>,
16}
17
18impl<IO: PacketIO> ConnectedControl<IO> {
19 pub fn new(control_addr: SocketAddr, udp: IO, pong: Pong) -> Self {
20 ConnectedControl { control_addr, packet_io: udp, pong_latest: pong, buffer: Vec::with_capacity(1024) }
21 }
22
23 pub fn control_addr(&self) -> SocketAddr {
24 self.control_addr
25 }
26
27 pub fn pong(&self) -> Pong {
28 self.pong_latest.clone()
29 }
30
31 pub async fn auth_into_established<A: AuthResource>(mut self, auth: A) -> Result<EstablishedControl<A, IO>, SetupError> {
32 let registered = self.authenticate(&auth).await?;
33 Ok(self.into_established(auth, registered))
34 }
35
36 pub fn into_established<A: AuthResource>(self, auth: A, registered: AgentRegistered) -> EstablishedControl<A, IO> {
37 let pong = self.pong_latest.clone();
38
39 EstablishedControl {
40 auth,
41 conn: self,
42 pong_at_auth: pong,
43 registered,
44 current_ping: None,
45 clock_offset: 0,
46 force_expired: false,
47 }
48 }
49
50 pub fn reset_established<A: AuthResource>(self, established: &mut EstablishedControl<A, IO>, registered: AgentRegistered) {
51 established.registered = registered;
52 established.pong_at_auth = self.pong_latest.clone();
53 established.conn = self;
54 established.current_ping = None;
55 established.force_expired = false;
56 }
57
58 pub async fn authenticate<A: AuthResource>(&mut self, auth: &A) -> Result<AgentRegistered, SetupError> {
59 let auth_pong = self.pong_latest.clone();
60 let res = auth.authenticate(&auth_pong).await?;
61
62 let bytes = match hex::decode(&res.key) {
63 Ok(data) => data,
64 Err(_) => return Err(SetupError::FailedToDecodeSignedAgentRegisterHex),
65 };
66
67 let request_id = now_milli();
68
69 for _ in 0..5 {
70 self.send(&ControlRpcMessage {
71 request_id,
72 content: RawSlice(&bytes),
73 }).await?;
74
75 for _ in 0..5 {
76 let mesage = match tokio::time::timeout(Duration::from_millis(500), self.recv()).await {
77 Ok(Ok(msg)) => msg,
78 Ok(Err(error)) => {
79 tracing::error!(?error, "got error reading from socket");
80 break;
81 }
82 Err(_) => {
83 tracing::error!("timeout waiting for register response");
84 continue;
85 }
86 };
87
88 let response = match mesage {
89 ControlFeed::Response(response) if response.request_id == request_id => response,
90 other => {
91 tracing::error!(?other, "got unexpected response from register request");
92 continue;
93 }
94 };
95
96 return match response.content {
97 ControlResponse::AgentRegistered(registered) => Ok(registered),
98 ControlResponse::InvalidSignature => Err(SetupError::RegisterInvalidSignature),
99 ControlResponse::Unauthorized => {
100 let _ = self.send(&ControlRpcMessage {
102 request_id,
103 content: ControlRequest::Ping(Ping {
104 now: now_milli(),
105 current_ping: None,
106 session_id: None,
107 }),
108 }).await;
109
110 Err(SetupError::RegisterUnauthorized)
111 },
112 ControlResponse::Pong(pong) => {
113 if pong.client_addr != auth_pong.client_addr || pong.tunnel_addr != auth_pong.tunnel_addr {
114 Err(SetupError::AttemptingToAuthWithOldFlow)
115 } else {
116 continue;
117 }
118 }
119 ControlResponse::RequestQueued => {
120 tracing::info!("register queued, waiting 1s");
121 tokio::time::sleep(Duration::from_secs(1)).await;
122 break;
123 }
124 other => {
125 tracing::error!(?other, "expected AgentRegistered but got something else");
126 continue;
127 }
128 };
129 }
130 }
131
132 Err(SetupError::FailedToConnect)
133 }
134
135 pub async fn send<M: MessageEncoding>(&mut self, msg: &M) -> std::io::Result<()> {
136 self.buffer.clear();
137 msg.write_to(&mut self.buffer)?;
138 self.packet_io.send_to(&self.buffer, self.control_addr).await?;
139 Ok(())
140 }
141
142 pub async fn recv(&mut self) -> Result<ControlFeed, ControlError> {
143 self.buffer.resize(1024, 0);
144
145 let (bytes, remote) = self.packet_io.recv_from(&mut self.buffer).await?;
146 if remote != self.control_addr {
147 return Err(ControlError::InvalidRemote { expected: self.control_addr, got: remote });
148 }
149
150 let mut reader = &self.buffer[..bytes];
151 let feed = ControlFeed::read_from(&mut reader).map_err(|e| ControlError::FailedToReadControlFeed(e))?;
152
153 if let ControlFeed::Response(ControlRpcMessage { content: ControlResponse::Pong(pong), .. }) = &feed {
154 self.pong_latest = pong.clone();
155 }
156
157 Ok(feed)
158 }
159}
160