1use core::fmt;
2use std::str::FromStr;
3
4use http::StatusCode;
5use lazy_static::lazy_static;
6use regex::Regex;
7use serde::{
8 de::{Error, Visitor},
9 ser::SerializeSeq,
10 Deserialize, Serialize,
11};
12use thiserror::Error;
13use wildmatch::WildMatch;
14
15use crate::{
16 format::*,
17 impl_into_status_code,
18 state::{NodeKey, NodeType},
19};
20
21#[derive(Debug, Error)]
22#[error("invalid node target string")]
23pub struct NodeTargetError;
24
25impl_into_status_code!(NodeTargetError, |_| StatusCode::BAD_REQUEST);
26
27#[derive(Debug, Clone, Default, Hash, PartialEq, Eq)]
30pub enum NodeTargets {
31 #[default]
32 None,
33 One(NodeTarget),
34 Many(Vec<NodeTarget>),
35}
36
37impl DataFormat for NodeTargets {
38 type Header = DataHeaderOf<NodeTarget>;
39 const LATEST_HEADER: Self::Header = NodeTarget::LATEST_HEADER;
40
41 fn write_data<W: std::io::prelude::Write>(
42 &self,
43 writer: &mut W,
44 ) -> Result<usize, DataWriteError> {
45 match self {
46 NodeTargets::None => vec![],
47 NodeTargets::One(target) => vec![target.clone()],
48 NodeTargets::Many(targets) => targets.clone(),
49 }
50 .write_data(writer)
51 }
52
53 fn read_data<R: std::io::prelude::Read>(
54 reader: &mut R,
55 header: &Self::Header,
56 ) -> Result<Self, DataReadError> {
57 let targets = Vec::<NodeTarget>::read_data(reader, header)?;
58 Ok(NodeTargets::from(targets))
59 }
60}
61
62impl<'de> Deserialize<'de> for NodeTargets {
63 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
64 where
65 D: serde::Deserializer<'de>,
66 {
67 struct NodeTargetsVisitor;
68
69 impl<'de> Visitor<'de> for NodeTargetsVisitor {
70 type Value = NodeTargets;
71
72 fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
73 formatter.write_str("one or more node targets")
74 }
75
76 fn visit_str<E: Error>(self, v: &str) -> Result<Self::Value, E> {
77 if v.contains(',') {
78 return Ok(NodeTargets::Many(
79 v.split(',')
80 .map(|s| NodeTarget::from_str(s.trim()).map_err(E::custom))
81 .collect::<Result<_, _>>()?,
82 ));
83 }
84 Ok(NodeTargets::One(FromStr::from_str(v).map_err(E::custom)?))
85 }
86
87 fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
88 where
89 A: serde::de::SeqAccess<'de>,
90 {
91 let mut buf = vec![];
92
93 while let Some(elem) = seq.next_element()? {
94 buf.push(NodeTarget::from_str(elem).map_err(A::Error::custom)?);
95 }
96
97 Ok(if buf.is_empty() {
98 NodeTargets::None
99 } else {
100 NodeTargets::Many(buf)
101 })
102 }
103 }
104
105 deserializer.deserialize_any(NodeTargetsVisitor)
106 }
107}
108
109lazy_static! {
110 static ref NODE_TARGET_REGEX: Regex =
111 Regex::new(r"^(?P<ty>\*|client|validator|prover)\/(?P<id>[A-Za-z0-9\-*]+)(?:@(?P<ns>[A-Za-z0-9\-*]+))?$")
112 .unwrap();
113}
114
115impl Serialize for NodeTargets {
116 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
117 where
118 S: serde::Serializer,
119 {
120 match self {
121 NodeTargets::None => serializer.serialize_seq(Some(0))?.end(),
122 NodeTargets::One(target) => serializer.serialize_str(&target.to_string()),
123 NodeTargets::Many(targets) => {
124 let mut seq = serializer.serialize_seq(Some(targets.len()))?;
125 for target in targets {
126 seq.serialize_element(&target.to_string())?;
127 }
128 seq.end()
129 }
130 }
131 }
132}
133
134impl fmt::Display for NodeTargets {
135 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
136 match self {
137 NodeTargets::None => write!(f, ""),
138 NodeTargets::One(target) => write!(f, "{}", target),
139 NodeTargets::Many(targets) => {
140 let mut iter = targets.iter();
141 if let Some(target) = iter.next() {
142 write!(f, "{}", target)?;
143 for target in iter {
144 write!(f, ", {}", target)?;
145 }
146 }
147 Ok(())
148 }
149 }
150 }
151}
152
153impl NodeTargets {
154 pub const ALL: Self = Self::One(NodeTarget::ALL);
155
156 pub fn is_all(&self) -> bool {
157 if matches!(self, NodeTargets::One(NodeTarget::ALL)) {
158 return true;
159 }
160
161 if let NodeTargets::Many(targets) = self {
162 return targets.iter().any(|target| target == &NodeTarget::ALL);
163 }
164
165 false
166 }
167}
168
169#[derive(Debug, Clone, PartialEq, Hash, Eq)]
172pub struct NodeTarget {
173 pub ty: NodeTargetType,
174 pub id: NodeTargetId,
175 pub ns: NodeTargetNamespace,
176}
177
178impl FromStr for NodeTarget {
179 type Err = NodeTargetError;
180
181 fn from_str(s: &str) -> Result<Self, Self::Err> {
182 let captures = NODE_TARGET_REGEX.captures(s).ok_or(NodeTargetError)?;
183
184 let ty = match &captures["ty"] {
186 "*" => NodeTargetType::All,
187 "client" => NodeTargetType::One(NodeType::Client),
188 "validator" => NodeTargetType::One(NodeType::Validator),
189 "prover" => NodeTargetType::One(NodeType::Prover),
190 _ => unreachable!(),
191 };
192
193 let id = match &captures["id"] {
195 "*" => NodeTargetId::All,
197
198 id if id.contains('*') => NodeTargetId::WildcardPattern(WildMatch::new(id)),
200
201 id => NodeTargetId::Literal(id.into()),
203 };
204
205 let ns = match captures.name("ns") {
207 Some(id) if id.as_str() == "*" => NodeTargetNamespace::All,
209
210 Some(id) if id.as_str() == "local" => NodeTargetNamespace::Local,
212 None => NodeTargetNamespace::Local,
213
214 Some(id) => NodeTargetNamespace::Literal(id.as_str().into()),
216 };
217
218 Ok(Self { ty, id, ns })
219 }
220}
221
222impl fmt::Display for NodeTarget {
223 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
224 write!(
225 f,
226 "{}/{}{}",
227 match self.ty {
228 NodeTargetType::All => "*".to_owned(),
229 NodeTargetType::One(ty) => ty.to_string(),
230 },
231 match &self.id {
232 NodeTargetId::All => "*".to_owned(),
233 NodeTargetId::WildcardPattern(pattern) => pattern.to_string(),
234 NodeTargetId::Literal(id) => id.to_owned(),
235 },
236 match &self.ns {
237 NodeTargetNamespace::All => "@*".to_owned(),
238 NodeTargetNamespace::Local => "".to_owned(),
239 NodeTargetNamespace::Literal(ns) => format!("@{}", ns),
240 }
241 )
242 }
243}
244
245impl Serialize for NodeTarget {
246 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
247 where
248 S: serde::Serializer,
249 {
250 serializer.serialize_str(&self.to_string())
251 }
252}
253
254impl<'de> Deserialize<'de> for NodeTarget {
255 fn deserialize<D>(deserializer: D) -> Result<NodeTarget, D::Error>
256 where
257 D: serde::Deserializer<'de>,
258 {
259 let s = String::deserialize(deserializer)?;
260 NodeTarget::from_str(&s).map_err(D::Error::custom)
261 }
262}
263
/// Binary encoding for a single [`NodeTarget`].
///
/// Each part is written as a `u8` tag, optionally followed by a payload, in this order:
/// 1. type: `0` = any, `1` = a specific `NodeType` (encoded right after the tag);
/// 2. id: `0` = any, `1` = wildcard pattern (its text, as a string), `2` = literal id;
/// 3. namespace: `0` = any, `1` = local, `2` = literal namespace.
///
/// The header pairs this type's own format version (currently `1`) with `NodeType`'s header.
impl DataFormat for NodeTarget {
    type Header = (u8, DataHeaderOf<NodeType>);
    const LATEST_HEADER: Self::Header = (1, NodeType::LATEST_HEADER);

    /// Writes the tagged layout described on the impl and returns the sum of the counts
    /// returned by the individual writes.
    fn write_data<W: std::io::prelude::Write>(
        &self,
        writer: &mut W,
    ) -> Result<usize, DataWriteError> {
        let mut written = 0;
        written += match self.ty {
            NodeTargetType::All => 0u8.write_data(writer)?,
            NodeTargetType::One(ty) => 1u8.write_data(writer)? + ty.write_data(writer)?,
        };
        written += match &self.id {
            NodeTargetId::All => 0u8.write_data(writer)?,
            // Wildcards are stored as their pattern text and recompiled by `read_data`.
            NodeTargetId::WildcardPattern(pattern) => {
                1u8.write_data(writer)? + pattern.to_string().write_data(writer)?
            }
            NodeTargetId::Literal(id) => 2u8.write_data(writer)? + id.write_data(writer)?,
        };
        written += match &self.ns {
            NodeTargetNamespace::All => 0u8.write_data(writer)?,
            NodeTargetNamespace::Local => 1u8.write_data(writer)?,
            NodeTargetNamespace::Literal(ns) => 2u8.write_data(writer)? + ns.write_data(writer)?,
        };

        Ok(written)
    }

    /// Reads a target written by `write_data`.
    ///
    /// # Errors
    /// Returns `DataReadError::unsupported` if `header.0` is not the latest version, and
    /// `DataReadError::Custom` for any unknown type, id, or namespace tag.
    fn read_data<R: std::io::prelude::Read>(
        reader: &mut R,
        header: &Self::Header,
    ) -> Result<Self, DataReadError> {
        // Only the current version of the layout is understood.
        if header.0 != Self::LATEST_HEADER.0 {
            return Err(DataReadError::unsupported(
                "NodeTarget",
                Self::LATEST_HEADER.0,
                header.0,
            ));
        }

        // Type tag; the `NodeType` payload is decoded with the nested header `header.1`.
        let ty = match reader.read_data(&())? {
            0u8 => NodeTargetType::All,
            1u8 => NodeTargetType::One(NodeType::read_data(reader, &header.1)?),
            n => {
                return Err(DataReadError::Custom(format!(
                    "invalid NodeTarget type discriminant: {n}"
                )))
            }
        };

        // Id tag; a wildcard is rebuilt from its stored pattern text.
        let id = match reader.read_data(&())? {
            0u8 => NodeTargetId::All,
            1u8 => {
                let pattern = String::read_data(reader, &())?;
                NodeTargetId::WildcardPattern(WildMatch::new(&pattern))
            }
            2u8 => NodeTargetId::Literal(reader.read_data(&())?),
            n => {
                return Err(DataReadError::Custom(format!(
                    "invalid NodeTarget ID discriminant: {n}"
                )))
            }
        };

        // Namespace tag.
        let ns = match reader.read_data(&())? {
            0u8 => NodeTargetNamespace::All,
            1u8 => NodeTargetNamespace::Local,
            2u8 => NodeTargetNamespace::Literal(reader.read_data(&())?),
            n => {
                return Err(DataReadError::Custom(format!(
                    "invalid NodeTarget namespace discriminant: {n}"
                )))
            }
        };

        Ok(Self { ty, id, ns })
    }
}
343
344#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
345pub enum NodeTargetType {
346 All,
348 One(NodeType),
350}
351
352#[derive(Debug, Clone, PartialEq)]
353pub enum NodeTargetId {
354 All,
356 WildcardPattern(WildMatch),
358 Literal(String),
360}
361
362impl Eq for NodeTargetId {}
363
364impl std::hash::Hash for NodeTargetId {
365 fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
366 match self {
367 NodeTargetId::All => "*".hash(state),
368 NodeTargetId::WildcardPattern(pattern) => pattern.to_string().hash(state),
369 NodeTargetId::Literal(id) => id.hash(state),
370 }
371 }
372}
373
374#[derive(Debug, Clone, PartialEq, Hash, Eq, Serialize, Deserialize)]
375pub enum NodeTargetNamespace {
376 All,
378 Literal(String),
380 Local,
382}
383
384impl From<NodeKey> for NodeTarget {
385 fn from(value: NodeKey) -> Self {
386 Self {
387 ty: NodeTargetType::One(value.ty),
388 id: NodeTargetId::Literal(value.id),
389 ns: value
390 .ns
391 .map(NodeTargetNamespace::Literal)
392 .unwrap_or(NodeTargetNamespace::Local),
393 }
394 }
395}
396
397impl From<Vec<NodeTarget>> for NodeTargets {
398 fn from(nodes: Vec<NodeTarget>) -> Self {
399 match nodes.len() {
400 0 => Self::None,
401 1 => Self::One(nodes.into_iter().next().unwrap()),
402 _ => Self::Many(nodes),
403 }
404 }
405}
406
407impl NodeTarget {
408 pub const ALL: Self = Self {
409 ty: NodeTargetType::All,
410 id: NodeTargetId::All,
411 ns: NodeTargetNamespace::All,
412 };
413
414 pub fn matches(&self, key: &NodeKey) -> bool {
415 (match self.ty {
416 NodeTargetType::All => true,
417 NodeTargetType::One(ty) => ty == key.ty,
418 }) && (match &self.id {
419 NodeTargetId::All => true,
420 NodeTargetId::WildcardPattern(pattern) => pattern.matches(&key.id),
421 NodeTargetId::Literal(id) => &key.id == id,
422 }) && (match &self.ns {
423 NodeTargetNamespace::All => true,
424 NodeTargetNamespace::Local => key.ns.is_none() || key.ns == Some("local".into()),
425 NodeTargetNamespace::Literal(ns) => {
426 ns == "local" && key.ns.is_none()
427 || key.ns.as_ref().map_or(false, |key_ns| key_ns == ns)
428 }
429 })
430 }
431}
432
433impl NodeTargets {
434 pub fn is_empty(&self) -> bool {
435 if matches!(self, &NodeTargets::None) {
436 return true;
437 }
438
439 if let NodeTargets::Many(targets) = self {
440 return targets.is_empty();
441 }
442
443 false
444 }
445
446 pub fn matches(&self, key: &NodeKey) -> bool {
447 match self {
448 NodeTargets::None => false,
449 NodeTargets::One(target) => target.matches(key),
450 NodeTargets::Many(targets) => targets.iter().any(|target| target.matches(key)),
451 }
452 }
453}