Skip to main content

rlmesh_sandbox/
source.rs

1use std::fmt;
2
3use anyhow::{Result, bail};
4use serde::{Deserialize, Serialize};
5
6use crate::SandboxError;
7
8#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
9#[serde(tag = "kind", rename_all = "snake_case")]
10pub enum EnvironmentSourceRef {
11    Gym(GymSourceRef),
12    Hf(HfSourceRef),
13}
14
15impl EnvironmentSourceRef {
16    /// Parse a sandbox source reference (`gym://<id>`, `hf://<repo>`, or a bare
17    /// gymnasium env id).
18    pub fn parse(value: &str) -> std::result::Result<Self, SandboxError> {
19        Self::parse_inner(value).map_err(SandboxError::invalid_source)
20    }
21
22    fn parse_inner(value: &str) -> Result<Self> {
23        let value = value.trim();
24        if value.is_empty() {
25            bail!("sandbox source must not be empty");
26        }
27
28        if let Some(rest) = value.strip_prefix("gym://") {
29            return Self::parse_gym(rest);
30        }
31
32        if let Some(rest) = value.strip_prefix("hf://") {
33            return Ok(Self::Hf(HfSourceRef::parse(rest)?));
34        }
35
36        if value.contains("://") {
37            bail!("unsupported sandbox source '{value}'");
38        }
39
40        Self::parse_gym(value)
41    }
42
43    fn parse_gym(env_id: &str) -> Result<Self> {
44        let env_id = env_id.trim();
45        if env_id.is_empty() {
46            bail!("gym source must include an environment id");
47        }
48        Ok(Self::Gym(GymSourceRef {
49            env_id: env_id.to_string(),
50        }))
51    }
52
53    pub fn slug(&self) -> String {
54        match self {
55            Self::Gym(source) => sanitize_slug(&source.env_id),
56            Self::Hf(source) => {
57                let mut value = source.repo.replace('/', "-");
58                if let Some(suite) = &source.suite {
59                    value.push('-');
60                    value.push_str(suite);
61                }
62                if let Some(task) = &source.task {
63                    value.push('-');
64                    value.push_str(task);
65                }
66                sanitize_slug(&value)
67            }
68        }
69    }
70}
71
72impl fmt::Display for EnvironmentSourceRef {
73    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
74        match self {
75            Self::Gym(source) => write!(f, "gym://{}", source.env_id),
76            Self::Hf(source) => {
77                write!(f, "hf://{}", source.repo)?;
78                if let Some(revision) = &source.revision {
79                    write!(f, "@{revision}")?;
80                }
81                if let Some(suite) = &source.suite {
82                    write!(f, ":{suite}")?;
83                }
84                if let Some(task) = &source.task {
85                    write!(f, "/{task}")?;
86                }
87                Ok(())
88            }
89        }
90    }
91}
92
93#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
94pub struct GymSourceRef {
95    pub env_id: String,
96}
97
98#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
99pub struct HfSourceRef {
100    pub repo: String,
101    pub revision: Option<String>,
102    pub suite: Option<String>,
103    pub task: Option<String>,
104}
105
106impl HfSourceRef {
107    fn parse(value: &str) -> Result<Self> {
108        let value = value.trim();
109        if value.is_empty() {
110            bail!("hugging face source must include org/repo");
111        }
112
113        let (repo_and_revision, suite, task) = match value.rsplit_once(':') {
114            Some((left, right)) if !left.is_empty() && !right.is_empty() => {
115                let (suite, task) = parse_selector(right)?;
116                (left, Some(suite), task)
117            }
118            _ => (value, None, None),
119        };
120
121        let (repo, revision) = match repo_and_revision.rsplit_once('@') {
122            Some((left, right)) if !left.is_empty() && !right.is_empty() => {
123                (left, Some(validate_ref_part("revision", right)?))
124            }
125            Some(_) => bail!("hugging face revision must look like @revision"),
126            None => (repo_and_revision, None),
127        };
128
129        validate_hf_repo(repo)?;
130
131        Ok(Self {
132            repo: repo.to_string(),
133            revision,
134            suite,
135            task,
136        })
137    }
138}
139
140#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
141#[serde(tag = "kind", rename_all = "snake_case")]
142pub(crate) enum ResolvedEnvironmentSourceRef {
143    Gym(GymSourceRef),
144    Hf(ResolvedHfSourceRef),
145}
146
147impl ResolvedEnvironmentSourceRef {
148    pub(crate) fn slug(&self) -> String {
149        match self {
150            Self::Gym(source) => sanitize_slug(&source.env_id),
151            Self::Hf(source) => {
152                let mut value = source.repo.replace('/', "-");
153                if let Some(suite) = &source.suite {
154                    value.push('-');
155                    value.push_str(suite);
156                }
157                if let Some(task) = &source.task {
158                    value.push('-');
159                    value.push_str(task);
160                }
161                sanitize_slug(&value)
162            }
163        }
164    }
165}
166
167impl fmt::Display for ResolvedEnvironmentSourceRef {
168    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
169        match self {
170            Self::Gym(source) => write!(f, "gym://{}", source.env_id),
171            Self::Hf(source) => {
172                write!(f, "hf://{}@{}", source.repo, source.resolved_revision)?;
173                if let Some(suite) = &source.suite {
174                    write!(f, ":{suite}")?;
175                }
176                if let Some(task) = &source.task {
177                    write!(f, "/{task}")?;
178                }
179                Ok(())
180            }
181        }
182    }
183}
184
185#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
186pub(crate) struct ResolvedHfSourceRef {
187    pub repo: String,
188    pub resolved_revision: String,
189    pub suite: Option<String>,
190    pub task: Option<String>,
191}
192
/// Turns an arbitrary string into a lowercase, dash-separated slug.
///
/// ASCII letters are lowercased and ASCII digits are kept. Every other
/// character (including non-ASCII ones) becomes `-`, runs of dashes collapse
/// into one, and leading/trailing dashes are removed. When nothing is left
/// the slug is `"env"`.
pub fn sanitize_slug(value: &str) -> String {
    let mut collapsed = String::new();
    for ch in value.chars() {
        if ch.is_ascii_lowercase() || ch.is_ascii_digit() {
            collapsed.push(ch);
        } else if ch.is_ascii_uppercase() {
            collapsed.push(ch.to_ascii_lowercase());
        } else if !collapsed.ends_with('-') {
            // Only the first separator of a run is kept.
            collapsed.push('-');
        }
    }

    match collapsed.trim_matches('-') {
        "" => "env".to_string(),
        trimmed => trimmed.to_string(),
    }
}
222
223fn validate_hf_repo(repo: &str) -> Result<()> {
224    let mut parts = repo.split('/');
225    let Some(owner) = parts.next() else {
226        bail!("hugging face sources must look like hf://org/repo[@revision][:suite[/task]]");
227    };
228    let Some(name) = parts.next() else {
229        bail!("hugging face sources must look like hf://org/repo[@revision][:suite[/task]]");
230    };
231    if parts.next().is_some() || owner.is_empty() || name.is_empty() {
232        bail!("hugging face sources must look like hf://org/repo[@revision][:suite[/task]]");
233    }
234    validate_hf_repo_part("owner", owner)?;
235    validate_hf_repo_part("repo", name)?;
236    Ok(())
237}
238
239fn parse_selector(value: &str) -> Result<(String, Option<String>)> {
240    let (suite, task) = match value.split_once('/') {
241        Some((suite, task)) if !suite.is_empty() && !task.is_empty() && !task.contains('/') => (
242            validate_ref_part("suite", suite)?,
243            Some(validate_ref_part("task", task)?),
244        ),
245        Some(_) => bail!("hugging face selector must look like :suite or :suite/task"),
246        None => (validate_ref_part("suite", value)?, None),
247    };
248    Ok((suite, task))
249}
250
251fn validate_hf_repo_part(label: &str, value: &str) -> Result<()> {
252    validate_ref_part(label, value)?;
253    if value.starts_with(['-', '.']) || value.ends_with(['-', '.']) {
254        bail!("hugging face {label} must not start or end with '-' or '.'");
255    }
256    if value.contains("--") || value.contains("..") {
257        bail!("hugging face {label} must not contain '--' or '..'");
258    }
259    if !value
260        .chars()
261        .all(|ch| ch.is_ascii_alphanumeric() || matches!(ch, '-' | '_' | '.'))
262    {
263        bail!("hugging face {label} may only contain ASCII letters, digits, '-', '_', and '.'");
264    }
265    Ok(())
266}
267
268fn validate_ref_part(label: &str, value: &str) -> Result<String> {
269    let value = value.trim();
270    if value.is_empty() {
271        bail!("{label} must not be empty");
272    }
273    if value.contains(char::is_whitespace) {
274        bail!("{label} must not contain whitespace");
275    }
276    // Reject leading '-' so the value can never be reparsed as a CLI option
277    // when it is later handed to git (e.g. as a revision passed to ls-remote).
278    if value.starts_with('-') {
279        bail!("{label} must not start with '-'");
280    }
281    Ok(value.to_string())
282}
283
#[cfg(test)]
mod tests {
    use super::{EnvironmentSourceRef, HfSourceRef, sanitize_slug};

    /// Parses `input`, panicking when the reference is rejected.
    fn parse_ok(input: &str) -> EnvironmentSourceRef {
        EnvironmentSourceRef::parse(input).unwrap()
    }

    /// Parses `input`, expecting a rejection, and returns the rendered error.
    fn parse_err(input: &str) -> String {
        EnvironmentSourceRef::parse(input).unwrap_err().to_string()
    }

    #[test]
    fn parses_plain_gym_sources() {
        let EnvironmentSourceRef::Gym(gym) = parse_ok("CartPole-v1") else {
            panic!("expected gym");
        };
        assert_eq!(gym.env_id, "CartPole-v1");
    }

    #[test]
    fn parses_gym_scheme_sources() {
        assert_eq!(parse_ok("gym://CartPole-v1").to_string(), "gym://CartPole-v1");
    }

    #[test]
    fn parses_hf_sources() {
        let hf = HfSourceRef::parse("org/repo@main:suite_1").unwrap();
        assert_eq!(hf.repo, "org/repo");
        assert_eq!(hf.revision.as_deref(), Some("main"));
        assert_eq!(hf.suite.as_deref(), Some("suite_1"));
        assert_eq!(hf.task, None);
    }

    #[test]
    fn parses_hf_sources_with_suite_and_task() {
        let hf = HfSourceRef::parse("org/repo@main:suite_1/0").unwrap();
        assert_eq!(hf.repo, "org/repo");
        assert_eq!(hf.revision.as_deref(), Some("main"));
        assert_eq!(hf.suite.as_deref(), Some("suite_1"));
        assert_eq!(hf.task.as_deref(), Some("0"));
    }

    #[test]
    fn parses_hf_source_refs() {
        // Each reference must render back to exactly what was parsed.
        for input in ["hf://org/repo", "hf://org/repo@main:suite_1/0"] {
            assert_eq!(parse_ok(input).to_string(), input);
        }
    }

    #[test]
    fn hf_slug_includes_suite_and_task() {
        assert_eq!(
            parse_ok("hf://org/repo@main:suite_1/0").slug(),
            "org-repo-suite-1-0"
        );
    }

    #[test]
    fn rejects_malformed_hf_selectors() {
        for input in [
            "hf://org/repo@main:suite/",
            "hf://org/repo@main:suite/task/extra",
        ] {
            assert!(parse_err(input).contains(":suite/task"));
        }
    }

    #[test]
    fn rejects_invalid_hf_sources() {
        assert!(parse_err("hf://org").contains("hf://org/repo"));
    }

    #[test]
    fn rejects_suspicious_hf_repo_parts() {
        assert!(parse_err("hf://org/repo?x=1").contains("may only contain ASCII"));
        assert!(parse_err("hf://org/..repo").contains("must not start or end"));
    }

    #[test]
    fn slug_sanitizes_input() {
        assert_eq!(sanitize_slug("sai_mujoco:Franka"), "sai-mujoco-franka");
    }
}