1use serde::{Deserialize, Serialize};
4
5use crate::Error;
6
/// A single Python dependency requirement, e.g.
/// `requests[security,socks]>=2.0 ; python_version >= '3.8'`.
///
/// Build one with [`Requirement::new`] plus the `with_*` builders, or obtain
/// one from [`Requirement::parse`].
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct Requirement {
    /// Distribution name, e.g. `requests` or `PySocks`.
    pub name: String,
    /// Version specifier text, e.g. `>=2.0` or `!=1.5.7,>=1.5.6`;
    /// `None` when the requirement is unconstrained.
    pub specifier: Option<String>,
    /// Extras requested in brackets, e.g. `["security", "socks"]`.
    pub extras: Vec<String>,
    /// Environment marker text that followed `;`, e.g. `python_version >= '3.8'`.
    pub marker: Option<String>,
    /// Direct URL reference. NOTE(review): `Requirement::parse` never sets
    /// this field, so it is presumably populated elsewhere — confirm.
    pub url: Option<String>,
}
21
22impl Requirement {
23 pub fn new(name: impl Into<String>) -> Self {
25 Self {
26 name: name.into(),
27 specifier: None,
28 extras: vec![],
29 marker: None,
30 url: None,
31 }
32 }
33
34 pub fn parse(s: &str) -> Result<Self, Error> {
43 let s = s.trim();
44 if s.is_empty() {
45 return Err(Error::InvalidDependency(s.to_string()));
46 }
47
48 let (main_part, marker) = if let Some(semi_pos) = s.find(';') {
50 let marker = s[semi_pos + 1..].trim().to_string();
51 let main = s[..semi_pos].trim();
52 (
53 main,
54 if marker.is_empty() {
55 None
56 } else {
57 Some(marker)
58 },
59 )
60 } else {
61 (s, None)
62 };
63
64 let mut name = String::new();
66 let mut extras = Vec::new();
67 let mut specifier = None;
68 let mut chars = main_part.chars().peekable();
69
70 while let Some(&c) = chars.peek() {
72 if c.is_alphanumeric() || c == '-' || c == '_' || c == '.' {
73 name.push(c);
74 chars.next();
75 } else {
76 break;
77 }
78 }
79
80 if name.is_empty() {
81 return Err(Error::InvalidDependency(s.to_string()));
82 }
83
84 if chars.peek() == Some(&'[') {
86 chars.next(); let mut extra = String::new();
88 while let Some(&c) = chars.peek() {
89 if c == ']' {
90 chars.next();
91 if !extra.is_empty() {
92 extras.push(extra.trim().to_string());
93 }
94 break;
95 } else if c == ',' {
96 chars.next();
97 if !extra.is_empty() {
98 extras.push(extra.trim().to_string());
99 extra = String::new();
100 }
101 } else {
102 extra.push(c);
103 chars.next();
104 }
105 }
106 }
107
108 while chars.peek() == Some(&' ') {
110 chars.next();
111 }
112
113 let remaining: String = chars.collect();
115 if !remaining.is_empty() {
116 specifier = Some(remaining.trim().to_string());
117 }
118
119 Ok(Self {
120 name,
121 specifier,
122 extras,
123 marker,
124 url: None,
125 })
126 }
127
128 pub fn with_extra(mut self, extra: impl Into<String>) -> Self {
130 self.extras.push(extra.into());
131 self
132 }
133
134 pub fn with_specifier(mut self, specifier: impl Into<String>) -> Self {
136 self.specifier = Some(specifier.into());
137 self
138 }
139
140 pub fn with_marker(mut self, marker: impl Into<String>) -> Self {
142 self.marker = Some(marker.into());
143 self
144 }
145}
146
147impl std::fmt::Display for Requirement {
148 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
149 write!(f, "{}", self.name)?;
150
151 if !self.extras.is_empty() {
152 write!(f, "[{}]", self.extras.join(","))?;
153 }
154
155 if let Some(ref spec) = self.specifier {
156 write!(f, "{}", spec)?;
157 }
158
159 if let Some(ref marker) = self.marker {
160 write!(f, " ; {}", marker)?;
161 }
162
163 Ok(())
164 }
165}
166
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_parse_simple() {
        let req = Requirement::parse("requests").unwrap();
        assert_eq!(req.name, "requests");
        assert!(req.specifier.is_none());
        assert!(req.marker.is_none());
        assert!(req.extras.is_empty());
    }

    #[test]
    fn test_parse_with_version() {
        let req = Requirement::parse("requests>=2.0").unwrap();
        assert_eq!(req.name, "requests");
        assert_eq!(req.specifier, Some(">=2.0".to_string()));
    }

    #[test]
    fn test_parse_with_extras() {
        let req = Requirement::parse("requests[security,socks]>=2.0").unwrap();
        assert_eq!(req.name, "requests");
        assert_eq!(req.extras, vec!["security", "socks"]);
        assert_eq!(req.specifier, Some(">=2.0".to_string()));
    }

    #[test]
    fn test_parse_with_marker() {
        let req = Requirement::parse("requests>=2.0 ; python_version >= '3.8'").unwrap();
        assert_eq!(req.name, "requests");
        assert_eq!(req.specifier, Some(">=2.0".to_string()));
        assert_eq!(req.marker, Some("python_version >= '3.8'".to_string()));
    }

    #[test]
    fn test_parse_pysocks() {
        let req = Requirement::parse("PySocks!=1.5.7,>=1.5.6; extra == 'socks'").unwrap();
        assert_eq!(req.name, "PySocks");
        assert_eq!(req.specifier, Some("!=1.5.7,>=1.5.6".to_string()));
        assert_eq!(req.marker, Some("extra == 'socks'".to_string()));
    }

    #[test]
    fn test_parse_rejects_blank_input() {
        assert!(Requirement::parse("").is_err());
        assert!(Requirement::parse("   ").is_err());
    }

    #[test]
    fn test_parse_rejects_missing_name() {
        assert!(Requirement::parse(">=2.0").is_err());
    }

    #[test]
    fn test_display() {
        let req = Requirement::new("requests").with_specifier(">=2.0");
        assert_eq!(req.to_string(), "requests>=2.0");
    }

    #[test]
    fn test_display_with_extras() {
        let req = Requirement::new("requests")
            .with_extra("security")
            .with_extra("socks")
            .with_specifier(">=2.0");
        assert_eq!(req.to_string(), "requests[security,socks]>=2.0");
    }

    #[test]
    fn test_display_with_marker() {
        let req = Requirement::new("requests")
            .with_specifier(">=2.0")
            .with_marker("python_version >= '3.8'");
        assert_eq!(req.to_string(), "requests>=2.0 ; python_version >= '3.8'");
    }

    #[test]
    fn test_display_parse_roundtrip() {
        let req = Requirement::new("requests")
            .with_extra("socks")
            .with_specifier(">=2.0")
            .with_marker("python_version >= '3.8'");
        assert_eq!(Requirement::parse(&req.to_string()).unwrap(), req);
    }
}