1use std::fmt;
2
3use anyhow::{Result, bail};
4use serde::{Deserialize, Serialize};
5
6use crate::SandboxError;
7
8#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
9#[serde(tag = "kind", rename_all = "snake_case")]
10pub enum EnvironmentSourceRef {
11 Gym(GymSourceRef),
12 Hf(HfSourceRef),
13}
14
15impl EnvironmentSourceRef {
16 pub fn parse(value: &str) -> std::result::Result<Self, SandboxError> {
19 Self::parse_inner(value).map_err(SandboxError::invalid_source)
20 }
21
22 fn parse_inner(value: &str) -> Result<Self> {
23 let value = value.trim();
24 if value.is_empty() {
25 bail!("sandbox source must not be empty");
26 }
27
28 if let Some(rest) = value.strip_prefix("gym://") {
29 return Self::parse_gym(rest);
30 }
31
32 if let Some(rest) = value.strip_prefix("hf://") {
33 return Ok(Self::Hf(HfSourceRef::parse(rest)?));
34 }
35
36 if value.contains("://") {
37 bail!("unsupported sandbox source '{value}'");
38 }
39
40 Self::parse_gym(value)
41 }
42
43 fn parse_gym(env_id: &str) -> Result<Self> {
44 let env_id = env_id.trim();
45 if env_id.is_empty() {
46 bail!("gym source must include an environment id");
47 }
48 Ok(Self::Gym(GymSourceRef {
49 env_id: env_id.to_string(),
50 }))
51 }
52
53 pub fn slug(&self) -> String {
54 match self {
55 Self::Gym(source) => sanitize_slug(&source.env_id),
56 Self::Hf(source) => {
57 let mut value = source.repo.replace('/', "-");
58 if let Some(suite) = &source.suite {
59 value.push('-');
60 value.push_str(suite);
61 }
62 if let Some(task) = &source.task {
63 value.push('-');
64 value.push_str(task);
65 }
66 sanitize_slug(&value)
67 }
68 }
69 }
70}
71
72impl fmt::Display for EnvironmentSourceRef {
73 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
74 match self {
75 Self::Gym(source) => write!(f, "gym://{}", source.env_id),
76 Self::Hf(source) => {
77 write!(f, "hf://{}", source.repo)?;
78 if let Some(revision) = &source.revision {
79 write!(f, "@{revision}")?;
80 }
81 if let Some(suite) = &source.suite {
82 write!(f, ":{suite}")?;
83 }
84 if let Some(task) = &source.task {
85 write!(f, "/{task}")?;
86 }
87 Ok(())
88 }
89 }
90 }
91}
92
93#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
94pub struct GymSourceRef {
95 pub env_id: String,
96}
97
98#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
99pub struct HfSourceRef {
100 pub repo: String,
101 pub revision: Option<String>,
102 pub suite: Option<String>,
103 pub task: Option<String>,
104}
105
106impl HfSourceRef {
107 fn parse(value: &str) -> Result<Self> {
108 let value = value.trim();
109 if value.is_empty() {
110 bail!("hugging face source must include org/repo");
111 }
112
113 let (repo_and_revision, suite, task) = match value.rsplit_once(':') {
114 Some((left, right)) if !left.is_empty() && !right.is_empty() => {
115 let (suite, task) = parse_selector(right)?;
116 (left, Some(suite), task)
117 }
118 _ => (value, None, None),
119 };
120
121 let (repo, revision) = match repo_and_revision.rsplit_once('@') {
122 Some((left, right)) if !left.is_empty() && !right.is_empty() => {
123 (left, Some(validate_ref_part("revision", right)?))
124 }
125 Some(_) => bail!("hugging face revision must look like @revision"),
126 None => (repo_and_revision, None),
127 };
128
129 validate_hf_repo(repo)?;
130
131 Ok(Self {
132 repo: repo.to_string(),
133 revision,
134 suite,
135 task,
136 })
137 }
138}
139
140#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
141#[serde(tag = "kind", rename_all = "snake_case")]
142pub(crate) enum ResolvedEnvironmentSourceRef {
143 Gym(GymSourceRef),
144 Hf(ResolvedHfSourceRef),
145}
146
147impl ResolvedEnvironmentSourceRef {
148 pub(crate) fn slug(&self) -> String {
149 match self {
150 Self::Gym(source) => sanitize_slug(&source.env_id),
151 Self::Hf(source) => {
152 let mut value = source.repo.replace('/', "-");
153 if let Some(suite) = &source.suite {
154 value.push('-');
155 value.push_str(suite);
156 }
157 if let Some(task) = &source.task {
158 value.push('-');
159 value.push_str(task);
160 }
161 sanitize_slug(&value)
162 }
163 }
164 }
165}
166
167impl fmt::Display for ResolvedEnvironmentSourceRef {
168 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
169 match self {
170 Self::Gym(source) => write!(f, "gym://{}", source.env_id),
171 Self::Hf(source) => {
172 write!(f, "hf://{}@{}", source.repo, source.resolved_revision)?;
173 if let Some(suite) = &source.suite {
174 write!(f, ":{suite}")?;
175 }
176 if let Some(task) = &source.task {
177 write!(f, "/{task}")?;
178 }
179 Ok(())
180 }
181 }
182 }
183}
184
185#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
186pub(crate) struct ResolvedHfSourceRef {
187 pub repo: String,
188 pub resolved_revision: String,
189 pub suite: Option<String>,
190 pub task: Option<String>,
191}
192
/// Converts arbitrary text into a slug.
///
/// ASCII letters are lowercased and ASCII digits kept. Every other character
/// (including non-ASCII) becomes `-`, runs of `-` collapse to one, and
/// leading/trailing `-` are stripped. Returns `"env"` when nothing survives,
/// so the result is never empty.
pub fn sanitize_slug(value: &str) -> String {
    let mut out = String::with_capacity(value.len());
    for ch in value.chars() {
        if ch.is_ascii_alphanumeric() {
            out.push(ch.to_ascii_lowercase());
        } else if !out.ends_with('-') {
            // Only emit a separator if the previous output char wasn't one.
            out.push('-');
        }
    }

    match out.trim_matches('-') {
        "" => String::from("env"),
        trimmed => trimmed.to_owned(),
    }
}
222
223fn validate_hf_repo(repo: &str) -> Result<()> {
224 let mut parts = repo.split('/');
225 let Some(owner) = parts.next() else {
226 bail!("hugging face sources must look like hf://org/repo[@revision][:suite[/task]]");
227 };
228 let Some(name) = parts.next() else {
229 bail!("hugging face sources must look like hf://org/repo[@revision][:suite[/task]]");
230 };
231 if parts.next().is_some() || owner.is_empty() || name.is_empty() {
232 bail!("hugging face sources must look like hf://org/repo[@revision][:suite[/task]]");
233 }
234 validate_hf_repo_part("owner", owner)?;
235 validate_hf_repo_part("repo", name)?;
236 Ok(())
237}
238
239fn parse_selector(value: &str) -> Result<(String, Option<String>)> {
240 let (suite, task) = match value.split_once('/') {
241 Some((suite, task)) if !suite.is_empty() && !task.is_empty() && !task.contains('/') => (
242 validate_ref_part("suite", suite)?,
243 Some(validate_ref_part("task", task)?),
244 ),
245 Some(_) => bail!("hugging face selector must look like :suite or :suite/task"),
246 None => (validate_ref_part("suite", value)?, None),
247 };
248 Ok((suite, task))
249}
250
251fn validate_hf_repo_part(label: &str, value: &str) -> Result<()> {
252 validate_ref_part(label, value)?;
253 if value.starts_with(['-', '.']) || value.ends_with(['-', '.']) {
254 bail!("hugging face {label} must not start or end with '-' or '.'");
255 }
256 if value.contains("--") || value.contains("..") {
257 bail!("hugging face {label} must not contain '--' or '..'");
258 }
259 if !value
260 .chars()
261 .all(|ch| ch.is_ascii_alphanumeric() || matches!(ch, '-' | '_' | '.'))
262 {
263 bail!("hugging face {label} may only contain ASCII letters, digits, '-', '_', and '.'");
264 }
265 Ok(())
266}
267
268fn validate_ref_part(label: &str, value: &str) -> Result<String> {
269 let value = value.trim();
270 if value.is_empty() {
271 bail!("{label} must not be empty");
272 }
273 if value.contains(char::is_whitespace) {
274 bail!("{label} must not contain whitespace");
275 }
276 if value.starts_with('-') {
279 bail!("{label} must not start with '-'");
280 }
281 Ok(value.to_string())
282}
283
#[cfg(test)]
mod tests {
    use super::{EnvironmentSourceRef, HfSourceRef, sanitize_slug};

    #[test]
    fn parses_plain_gym_sources() {
        let source = EnvironmentSourceRef::parse("CartPole-v1").unwrap();
        match source {
            EnvironmentSourceRef::Gym(source) => assert_eq!(source.env_id, "CartPole-v1"),
            _ => panic!("expected gym"),
        }
    }

    #[test]
    fn parses_gym_scheme_sources() {
        let source = EnvironmentSourceRef::parse("gym://CartPole-v1").unwrap();
        assert_eq!(source.to_string(), "gym://CartPole-v1");
    }

    #[test]
    fn trims_surrounding_whitespace() {
        let source = EnvironmentSourceRef::parse("  gym://CartPole-v1 \n").unwrap();
        assert_eq!(source.to_string(), "gym://CartPole-v1");
    }

    #[test]
    fn rejects_empty_sources() {
        let err = EnvironmentSourceRef::parse("   ").unwrap_err();
        assert!(err.to_string().contains("must not be empty"));

        let err = EnvironmentSourceRef::parse("gym://").unwrap_err();
        assert!(err.to_string().contains("environment id"));

        let err = EnvironmentSourceRef::parse("hf://").unwrap_err();
        assert!(err.to_string().contains("org/repo"));
    }

    #[test]
    fn rejects_unsupported_schemes() {
        let err = EnvironmentSourceRef::parse("s3://bucket/env").unwrap_err();
        assert!(err.to_string().contains("unsupported sandbox source"));
    }

    #[test]
    fn parses_hf_sources() {
        let source = HfSourceRef::parse("org/repo@main:suite_1").unwrap();
        assert_eq!(source.repo, "org/repo");
        assert_eq!(source.revision.as_deref(), Some("main"));
        assert_eq!(source.suite.as_deref(), Some("suite_1"));
        assert_eq!(source.task, None);
    }

    #[test]
    fn parses_hf_sources_with_suite_and_task() {
        let source = HfSourceRef::parse("org/repo@main:suite_1/0").unwrap();
        assert_eq!(source.repo, "org/repo");
        assert_eq!(source.revision.as_deref(), Some("main"));
        assert_eq!(source.suite.as_deref(), Some("suite_1"));
        assert_eq!(source.task.as_deref(), Some("0"));
    }

    #[test]
    fn parses_hf_source_refs() {
        let source = EnvironmentSourceRef::parse("hf://org/repo").unwrap();
        assert_eq!(source.to_string(), "hf://org/repo");

        let source = EnvironmentSourceRef::parse("hf://org/repo@main:suite_1/0").unwrap();
        assert_eq!(source.to_string(), "hf://org/repo@main:suite_1/0");
    }

    #[test]
    fn hf_sources_round_trip_through_display() {
        for input in [
            "hf://org/repo",
            "hf://org/repo@refs/pr/1",
            "hf://org/repo@v1.0:suite_1/task-2",
        ] {
            let source = EnvironmentSourceRef::parse(input).unwrap();
            assert_eq!(source.to_string(), input);
            assert_eq!(EnvironmentSourceRef::parse(&source.to_string()).unwrap(), source);
        }
    }

    #[test]
    fn hf_slug_includes_suite_and_task() {
        let source = EnvironmentSourceRef::parse("hf://org/repo@main:suite_1/0").unwrap();
        assert_eq!(source.slug(), "org-repo-suite-1-0");
    }

    #[test]
    fn gym_slug_sanitizes_env_id() {
        let source = EnvironmentSourceRef::parse("gym://ALE/Pong-v5").unwrap();
        assert_eq!(source.slug(), "ale-pong-v5");
    }

    #[test]
    fn rejects_malformed_hf_selectors() {
        let err = EnvironmentSourceRef::parse("hf://org/repo@main:suite/").unwrap_err();
        assert!(err.to_string().contains(":suite/task"));

        let err = EnvironmentSourceRef::parse("hf://org/repo@main:suite/task/extra").unwrap_err();
        assert!(err.to_string().contains(":suite/task"));
    }

    #[test]
    fn rejects_empty_or_flag_like_hf_revisions() {
        let err = EnvironmentSourceRef::parse("hf://org/repo@").unwrap_err();
        assert!(err.to_string().contains("@revision"));

        let err = EnvironmentSourceRef::parse("hf://org/repo@-rev").unwrap_err();
        assert!(err.to_string().contains("must not start with '-'"));
    }

    #[test]
    fn rejects_invalid_hf_sources() {
        let err = EnvironmentSourceRef::parse("hf://org").unwrap_err();
        assert!(err.to_string().contains("hf://org/repo"));
    }

    #[test]
    fn rejects_suspicious_hf_repo_parts() {
        let err = EnvironmentSourceRef::parse("hf://org/repo?x=1").unwrap_err();
        assert!(err.to_string().contains("may only contain ASCII"));

        let err = EnvironmentSourceRef::parse("hf://org/..repo").unwrap_err();
        assert!(err.to_string().contains("must not start or end"));
    }

    #[test]
    fn slug_sanitizes_input() {
        assert_eq!(sanitize_slug("sai_mujoco:Franka"), "sai-mujoco-franka");
    }

    #[test]
    fn slug_collapses_separators_and_falls_back_when_empty() {
        assert_eq!(sanitize_slug("  Multi   Space  "), "multi-space");
        assert_eq!(sanitize_slug(""), "env");
        assert_eq!(sanitize_slug("--//__"), "env");
    }
}