1use std::{net::SocketAddr, ops::Deref, sync::Arc};
17
18use async_trait::async_trait;
19use endhost_api_client::client::CrpcEndhostApiClient;
20use scion_sdk_reqwest_connect_rpc::{client::CrpcClientError, token_source::TokenSource};
21use url::Url;
22use x25519_dalek::PublicKey;
23
24use crate::{
25 crpc_api::api_service::{GET_SNAP_DATA_PLANE_ADDRESS, REGISTER_SNAPTUN_IDENTITY, SERVICE_PATH},
26 protobuf::anapaya::snap::v1::api_service as proto,
27};
28
29pub mod re_export {
31 pub use endhost_api_client::client::{CrpcEndhostApiClient, EndhostApiClient};
32 pub use scion_sdk_reqwest_connect_rpc::{client::CrpcClientError, token_source::*};
33}
34
35#[derive(Debug, Clone, PartialEq, Eq)]
37pub struct GetDataPlaneAddressResponse {
38 pub address: SocketAddr,
40 pub snap_tun_control_address: Option<SocketAddr>,
43 pub snap_static_x25519: Option<PublicKey>,
46}
47
48#[async_trait]
50pub trait ControlPlaneApi: Send + Sync {
51 async fn get_data_plane_address(&self) -> Result<GetDataPlaneAddressResponse, CrpcClientError>;
53
54 async fn register_snaptun_identity(
56 &self,
57 initiator_identity: PublicKey,
58 psk_share: Option<[u8; 32]>,
59 ) -> Result<Option<[u8; 32]>, CrpcClientError>;
60}
61
62pub struct CrpcSnapControlClient {
64 client: CrpcEndhostApiClient,
65}
66
67impl Deref for CrpcSnapControlClient {
68 type Target = CrpcEndhostApiClient;
69
70 fn deref(&self) -> &Self::Target {
71 &self.client
72 }
73}
74
75impl CrpcSnapControlClient {
76 pub fn new(base_url: &Url) -> anyhow::Result<Self> {
78 let client = CrpcEndhostApiClient::new(base_url)?;
79 Ok(Self { client })
80 }
81
82 pub fn new_with_client(base_url: &Url, client: reqwest::Client) -> anyhow::Result<Self> {
84 Ok(Self {
85 client: CrpcEndhostApiClient::new_with_client(base_url, client)?,
86 })
87 }
88
89 pub fn use_token_source(&mut self, token_source: Arc<dyn TokenSource>) -> &mut Self {
91 self.client.use_token_source(token_source);
92 self
93 }
94}
95
96#[async_trait]
97impl ControlPlaneApi for CrpcSnapControlClient {
98 async fn get_data_plane_address(&self) -> Result<GetDataPlaneAddressResponse, CrpcClientError> {
99 let res: proto::GetSnapDataPlaneResponse = self
100 .client
101 .unary_request::<proto::GetSnapDataPlaneRequest, proto::GetSnapDataPlaneResponse>(
102 &format!("{SERVICE_PATH}{GET_SNAP_DATA_PLANE_ADDRESS}"),
103 proto::GetSnapDataPlaneRequest::default(),
104 )
105 .await?;
106 let address = res.address.parse().map_err(|e: std::net::AddrParseError| {
107 CrpcClientError::DecodeError {
108 context: "parsing data plane address".into(),
109 source: e.into(),
110 body: None,
111 }
112 })?;
113
114 let snap_tun_control_address = res
115 .snap_tun_control_address
116 .map(|address| {
117 address.parse().map_err(|e: std::net::AddrParseError| {
118 CrpcClientError::DecodeError {
119 context: "parsing server control address".into(),
120 source: e.into(),
121 body: None,
122 }
123 })
124 })
125 .transpose()?;
126 let snap_static_x25519 = res
127 .snap_static_x25519
128 .map(|key| {
129 let key_bytes: [u8; 32] =
130 key.as_slice()
131 .try_into()
132 .map_err(|e: std::array::TryFromSliceError| {
133 CrpcClientError::DecodeError {
134 context: "server static identity is not 32 bytes".into(),
135 source: e.into(),
136 body: None,
137 }
138 })?;
139 Ok::<_, CrpcClientError>(PublicKey::from(key_bytes))
140 })
141 .transpose()?;
142 Ok(GetDataPlaneAddressResponse {
143 address,
144 snap_tun_control_address,
145 snap_static_x25519,
146 })
147 }
148
149 async fn register_snaptun_identity(
150 &self,
151 initiator_identity: PublicKey,
152 psk_share: Option<[u8; 32]>,
153 ) -> Result<Option<[u8; 32]>, CrpcClientError> {
154 let res = self.client.unary_request::<proto::RegisterSnapTunIdentityRequest, proto::RegisterSnapTunIdentityResponse>(
155 &format!("{SERVICE_PATH}{REGISTER_SNAPTUN_IDENTITY}"),
156 proto::RegisterSnapTunIdentityRequest { initiator_static_x25519: initiator_identity.to_bytes().to_vec(), psk_share: psk_share.unwrap_or([0u8;32]).to_vec() },
157 ).await?;
158 let psk_share = if res.psk_share.as_slice() == [0u8; 32] {
159 None
160 } else {
161 Some(res.psk_share.as_slice().try_into().map_err(
162 |e: std::array::TryFromSliceError| {
163 CrpcClientError::DecodeError {
164 context: "psk share is not 32 bytes".into(),
165 source: e.into(),
166 body: None,
167 }
168 },
169 )?)
170 };
171 Ok(psk_share)
172 }
173}