Skip to main content

lux_cli/project/
new.rs

1use std::{error::Error, fmt::Display, path::PathBuf, str::FromStr};
2
3use clap::Args;
4use eyre::{eyre, Result};
5use inquire::{
6    ui::{RenderConfig, Styled},
7    validator::Validation,
8    Confirm, Select, Text,
9};
10use itertools::Itertools;
11use spdx::LicenseId;
12use spinners::{Spinner, Spinners};
13
14use crate::utils::github_metadata::{self, RepoMetadata};
15use lux_lib::{
16    package::PackageReq,
17    project::{Project, PROJECT_TOML},
18};
19
20// TODO:
21// - Automatically detect build type to insert into rockspec by inspecting the current repo.
22//   E.g. if there is a `Cargo.toml` in the project root we can infer the user wants to use the
23//   Rust build backend.
24
/// The type of directory to create when making the project.
///
/// Its [`Display`] form (`src` or `lua`) is used as the name of the directory that receives the
/// generated `main.lua`, and as the path prefix of the `[run]` arguments in the manifest.
#[derive(Debug, Clone, clap::ValueEnum)]
enum SourceDirType {
    // Displayed (and created on disk) as `src`.
    Src,
    // Displayed (and created on disk) as `lua`.
    Lua,
}
31
32impl Display for SourceDirType {
33    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
34        match self {
35            Self::Src => write!(f, "src"),
36            Self::Lua => write!(f, "lua"),
37        }
38    }
39}
40
41#[derive(Args)]
42pub struct NewProject {
43    /// The directory of the project.
44    target: PathBuf,
45
46    /// The project's name.
47    #[arg(long)]
48    name: Option<String>,
49
50    /// The description of the project.
51    #[arg(long)]
52    description: Option<String>,
53
54    /// The license of the project. Generic license names will be inferred.
55    #[arg(long, value_parser = clap_parse_license)]
56    license: Option<LicenseId>,
57
58    /// The maintainer of this project. Does not have to be the code author.
59    #[arg(long)]
60    maintainer: Option<String>,
61
62    /// A comma-separated list of labels to apply to this project.
63    #[arg(long, value_parser = clap_parse_list)]
64    labels: Option<std::vec::Vec<String>>, // Note: full qualified name required, see https://github.com/clap-rs/clap/issues/4626
65
66    /// A version constraint on the required Lua version for this project.
67    /// Examples: ">=5.1", "5.1"
68    #[arg(long, value_parser = clap_parse_version)]
69    lua_versions: Option<PackageReq>,
70
71    #[arg(long)]
72    main: Option<SourceDirType>,
73}
74
75struct NewProjectValidated {
76    target: PathBuf,
77    name: String,
78    description: String,
79    maintainer: String,
80    labels: Vec<String>,
81    lua_versions: PackageReq,
82    main: SourceDirType,
83    license: Option<LicenseId>,
84}
85
86fn clap_parse_license(s: &str) -> std::result::Result<LicenseId, String> {
87    match validate_license(s) {
88        Ok(Validation::Valid) => unsafe { Ok(parse_license_unchecked(s)) },
89        Err(_) | Ok(Validation::Invalid(_)) => {
90            Err(format!("unable to identify license {s}, please try again!"))
91        }
92    }
93}
94
95fn clap_parse_version(input: &str) -> std::result::Result<PackageReq, String> {
96    PackageReq::from_str(format!("lua {input}").as_str()).map_err(|err| err.to_string())
97}
98
/// Parses a comma-separated list of labels from the command line.
///
/// Entries are trimmed of surrounding whitespace. Empty entries (from `a,,b`, a trailing
/// comma, or empty input) are dropped so they don't end up as `""` labels in the manifest.
///
/// # Errors
///
/// Rejects any ASCII punctuation other than `-`, `_` and the `,` separator. The message names
/// the first offending character and its 1-based column.
fn clap_parse_list(input: &str) -> std::result::Result<Vec<String>, String> {
    let offending = input
        .chars()
        .enumerate()
        .find(|&(_, c)| c != '-' && c != '_' && c != ',' && c.is_ascii_punctuation());

    if let Some((index, ch)) = offending {
        // `enumerate` is 0-based; columns are conventionally reported 1-based.
        let column = index + 1;
        Err(format!(
            r#"Unexpected punctuation '{ch}' found at column {column}.
    Lists are comma separated but names should not contain punctuation!"#
        ))
    } else {
        Ok(input
            .split(',')
            .map(str::trim)
            .filter(|label| !label.is_empty())
            .map(str::to_string)
            .collect())
    }
}
112
113/// Parses a license.
114///
115/// # Security
116///
117/// WARNING: This should only be invoked after validating the license with [`validate_license`].
118unsafe fn parse_license_unchecked(input: &str) -> LicenseId {
119    spdx::imprecise_license_id(input).unwrap_unchecked().0
120}
121
122fn validate_license(input: &str) -> std::result::Result<Validation, Box<dyn Error + Send + Sync>> {
123    if input == "none" {
124        return Ok(Validation::Valid);
125    }
126
127    Ok(
128        match spdx::imprecise_license_id(input).ok_or(format!(
129            "Unable to identify license '{input}', please try again!",
130        )) {
131            Ok(_) => Validation::Valid,
132            Err(err) => Validation::Invalid(err.into()),
133        },
134    )
135}
136
137pub async fn write_project_rockspec(cli_flags: NewProject) -> Result<()> {
138    let project = Project::from_exact(cli_flags.target.clone())?;
139    let render_config = RenderConfig::default_colored()
140        .with_prompt_prefix(Styled::new(">").with_fg(inquire::ui::Color::LightGreen));
141
142    // If the project already exists then ask for override confirmation
143    if project.is_some()
144        && !Confirm::new("Target directory already has a project, write anyway?")
145            .with_default(false)
146            .with_help_message(&format!("This may overwrite your existing {PROJECT_TOML}",))
147            .with_render_config(render_config)
148            .prompt()?
149    {
150        return Err(eyre!("cancelled creation of project (already exists)"));
151    };
152
153    let validated = match cli_flags {
154        // If all parameters are provided then don't bother prompting the user
155        NewProject {
156            description: Some(description),
157            main: Some(main),
158            labels: Some(labels),
159            lua_versions: Some(lua_versions),
160            maintainer: Some(maintainer),
161            name: Some(name),
162            license,
163            target,
164        } => Ok::<_, eyre::Report>(NewProjectValidated {
165            description,
166            labels,
167            license,
168            lua_versions,
169            main,
170            maintainer,
171            name,
172            target,
173        }),
174
175        NewProject {
176            description,
177            labels,
178            license,
179            lua_versions,
180            main,
181            maintainer,
182            name,
183            target,
184        } => {
185            let mut spinner = Spinner::new(
186                Spinners::Dots,
187                "Fetching remote repository metadata... ".into(),
188            );
189
190            let repo_metadata = match github_metadata::get_metadata_for(Some(&target)).await {
191                Ok(value) => value.map_or_else(|| RepoMetadata::default(&target), Ok),
192                Err(_) => {
193                    println!("Could not fetch remote repo metadata, defaulting to empty values.");
194
195                    RepoMetadata::default(&target)
196                }
197            }?;
198
199            spinner.stop_and_persist("✔", "Fetched remote repository metadata.".into());
200
201            let package_name = name.map_or_else(
202                || {
203                    Text::new("Package name:")
204                        .with_default(&repo_metadata.name)
205                        .with_help_message("A folder with the same name will be created for you.")
206                        .with_render_config(render_config)
207                        .prompt()
208                },
209                Ok,
210            )?;
211
212            let description = description.map_or_else(
213                || {
214                    Text::new("Description:")
215                        .with_default(&repo_metadata.description.unwrap_or_default())
216                        .with_render_config(render_config)
217                        .prompt()
218                },
219                Ok,
220            )?;
221
222            let license = license.map_or_else(
223                || {
224                    Ok::<_, eyre::Error>(
225                        match Text::new("License:")
226                            .with_default(&repo_metadata.license.unwrap_or("none".into()))
227                            .with_help_message("Type 'none' for no license")
228                            .with_validator(validate_license)
229                            .with_render_config(render_config)
230                            .prompt()?
231                            .as_str()
232                        {
233                            "none" => None,
234                            license => unsafe { Some(parse_license_unchecked(license)) },
235                        },
236                    )
237                },
238                |license| Ok(Some(license)),
239            )?;
240
241            let labels = labels.or(repo_metadata.labels).map_or_else(
242                || {
243                    Ok::<_, eyre::Error>(
244                        Text::new("Labels:")
245                            .with_placeholder("web,filesystem")
246                            .with_help_message("Labels are comma separated")
247                            .prompt()?
248                            .split(',')
249                            .map(|label| label.trim().to_string())
250                            .collect_vec(),
251                    )
252                },
253                Ok,
254            )?;
255
256            let maintainer = maintainer.map_or_else(
257                || {
258                    let prompt = Text::new("Maintainer:");
259                    if let Some(default_maintainer) = repo_metadata
260                        .contributors
261                        .first()
262                        .cloned()
263                        .or_else(|| whoami::realname().ok())
264                    {
265                        prompt.with_default(&default_maintainer).prompt()
266                    } else {
267                        prompt.prompt()
268                    }
269                },
270                Ok,
271            )?;
272
273            let lua_versions = lua_versions.map_or_else(
274                || {
275                    Ok::<_, eyre::Report>(
276                        format!(
277                            "lua >= {}",
278                            Select::new(
279                                "What is the lowest Lua version you support?",
280                                vec!["5.1", "5.2", "5.3", "5.4", "5.5"]
281                            )
282                            .without_filtering()
283                            .with_vim_mode(true)
284                            .with_help_message(
285                                "This is equivalent to the 'lua >= {version}' constraint."
286                            )
287                            .prompt()?
288                        )
289                        .parse()?,
290                    )
291                },
292                Ok,
293            )?;
294
295            Ok(NewProjectValidated {
296                target,
297                name: package_name,
298                description,
299                labels,
300                license,
301                lua_versions,
302                maintainer,
303                main: main.unwrap_or(SourceDirType::Src),
304            })
305        }
306    }?;
307
308    let _ = std::fs::create_dir_all(&validated.target);
309
310    let rocks_path = validated.target.join(PROJECT_TOML);
311
312    std::fs::write(
313        &rocks_path,
314        format!(
315            r#"
316package = "{package_name}"
317version = "0.1.0"
318lua = "{lua_version_req}"
319
320[description]
321summary = "{summary}"
322maintainer = "{maintainer}"
323labels = [ {labels} ]
324{license}
325
326[dependencies]
327# Add your dependencies here
328# `busted = ">=2.0"`
329
330[run]
331args = [ "{main}/main.lua" ]
332
333[build]
334type = "builtin"
335    "#,
336            package_name = validated.name,
337            summary = validated.description,
338            license = validated
339                .license
340                .map(|license| format!(r#"license = "{}""#, license.name))
341                .unwrap_or_default(),
342            maintainer = validated.maintainer,
343            labels = validated
344                .labels
345                .into_iter()
346                .map(|label| "\"".to_string() + &label + "\"")
347                .join(", "),
348            lua_version_req = validated.lua_versions.version_req(),
349            main = validated.main,
350        )
351        .trim(),
352    )?;
353
354    let main_dir = validated.target.join(validated.main.to_string());
355    if main_dir.exists() {
356        eprintln!(
357            "Directory `{}/` already exists - we won't make any changes to it.",
358            main_dir.display()
359        );
360    } else {
361        std::fs::create_dir(&main_dir)?;
362        std::fs::write(main_dir.join("main.lua"), r#"print("Hello world!")"#)?;
363    }
364
365    println!("All done!");
366
367    Ok(())
368}
369
370// TODO(vhyrro): Add tests