grafbase_local_backend/
project.rs

1use crate::consts::{DEFAULT_DOT_ENV, DEFAULT_SCHEMA, USER_AGENT};
2use crate::errors::BackendError;
3use async_compression::tokio::bufread::GzipDecoder;
4use async_tar::Archive;
5use common::consts::{GRAFBASE_DIRECTORY_NAME, GRAFBASE_ENV_FILE_NAME, GRAFBASE_SCHEMA_FILE_NAME};
6use common::environment::Environment;
7use http_cache_reqwest::{CACacheManager, Cache, CacheMode, HttpCache};
8use reqwest::{header, Client};
9use reqwest_middleware::ClientBuilder;
10use serde::Deserialize;
11use std::env;
12use std::fs;
13use std::io::{Error as IoError, ErrorKind as IoErrorKind};
14use std::iter::Iterator;
15use std::path::PathBuf;
16use tokio_stream::StreamExt;
17use tokio_util::compat::TokioAsyncReadCompatExt;
18use tokio_util::io::StreamReader;
19use url::Url;
20
21/// initializes a new project in the current or a new directory, optionally from a template
22///
23/// # Errors
24///
25/// ## General
26///
27/// - returns [`BackendError::ReadCurrentDirectory`] if the current directory could not be read
28///
29/// - returns [`BackendError::ProjectDirectoryExists`] if a named is passed and a directory with the same name already exists in the current directory
30///
31/// - returns [`BackendError::AlreadyAProject`] if there's already a grafbase/schema.graphql in the target
32///
33/// - returns [`BackendError::CreateGrafbaseDirectory`] if the grafbase directory could not be created
34///
35/// - returns [`BackendError::CreateProjectDirectory`] if the project directory could not be created
36///
37/// - returns [`BackendError::WriteSchema`] if the schema file could not be written
38///
39/// ## Templates
40///
41/// - returns [`BackendError::UnsupportedTemplateURL`] if a template URL is not supported
42///
43/// - returns [`BackendError::StartDownloadRepoArchive`] if a template URL is not supported (if the request could not be made)
44///
45/// - returns [`BackendError::DownloadRepoArchive`] if a repo tar could not be downloaded (on a non 200-299 status)
46///
47/// - returns [`BackendError::TemplateNotFound`] if no files matching the template path were extracted (excluding extraction errors)
48///
49/// - returns [`BackendError::MoveExtractedFiles`] if the extracted files from the template repository could not be moved
50///
51/// - returns [`BackendError::ReadArchiveEntries`] if the entries of the template repository archive could not be read
52///
53/// - returns [`BackendError::ExtractArchiveEntry`] if one of the entries of the template repository archive could not be extracted
54///
55/// - returns [`BackendError::CleanExtractedFiles`] if the files extracted from the template repository archive could not be cleaned
56///
57/// - returns [`BackendError::StartGetRepositoryInformation`] if the request to get the information for a repository could not be sent
58///
59/// - returns [`BackendError::GetRepositoryInformation`] if the request to get the information for a repository returned a non 200-299 status
60///
61/// - returns [`BackendError::ReadRepositoryInformation`] if the request to get the information for a repository returned a response that could not be parsed
62#[tokio::main]
63pub async fn init(name: Option<&str>, template: Option<&str>) -> Result<(), BackendError> {
64    let project_path = to_project_path(name)?;
65    let grafbase_path = project_path.join(GRAFBASE_DIRECTORY_NAME);
66    let schema_path = grafbase_path.join(GRAFBASE_SCHEMA_FILE_NAME);
67
68    if grafbase_path.exists() {
69        Err(BackendError::AlreadyAProject(grafbase_path))
70    } else if let Some(template) = template {
71        // as directory names cannot contain slashes, and URLs with no scheme or path cannot
72        // be differentiated from a valid template name,
73        // anything with a slash is treated as a URL
74        if template.contains('/') {
75            if let Ok(repo_url) = Url::parse(template) {
76                match repo_url.host_str() {
77                    Some("github.com") => handle_github_repo_url(grafbase_path, &repo_url).await,
78                    _ => Err(BackendError::UnsupportedTemplateURL(template.to_string())),
79                }
80            } else {
81                return Err(BackendError::MalformedTemplateURL(template.to_owned()));
82            }
83        } else {
84            download_github_template(
85                grafbase_path,
86                GitHubTemplate::Grafbase(GrafbaseGithubTemplate { path: template }),
87            )
88            .await
89        }
90    } else {
91        tokio::fs::create_dir_all(&grafbase_path)
92            .await
93            .map_err(BackendError::CreateGrafbaseDirectory)?;
94
95        let dot_env_path = grafbase_path.join(GRAFBASE_ENV_FILE_NAME);
96        let schema_write_result = fs::write(schema_path, DEFAULT_SCHEMA).map_err(BackendError::WriteSchema);
97        let dot_env_write_result = fs::write(dot_env_path, DEFAULT_DOT_ENV).map_err(BackendError::WriteSchema);
98
99        if schema_write_result.is_err() || dot_env_write_result.is_err() {
100            tokio::fs::remove_dir_all(&grafbase_path)
101                .await
102                .map_err(BackendError::DeleteGrafbaseDirectory)?;
103        }
104
105        schema_write_result?;
106        dot_env_write_result?;
107
108        Ok(())
109    }
110}
111
112async fn handle_github_repo_url(grafbase_path: PathBuf, repo_url: &Url) -> Result<(), BackendError> {
113    if let Some(mut segments) = repo_url.path_segments().map(Iterator::collect::<Vec<_>>) {
114        // remove trailing slashes to prevent extra path parameters
115        if segments.last() == Some(&"") {
116            segments.pop();
117        }
118
119        // disallow empty path paramters other than the last
120        if segments.contains(&"") {
121            return Err(BackendError::UnsupportedTemplateURL(repo_url.to_string()));
122        }
123
124        match segments.len() {
125            2 => {
126                let org = &segments[0];
127
128                let repo = &segments[1];
129
130                let branch = get_default_branch(org, repo).await?;
131
132                download_github_template(
133                    grafbase_path,
134                    GitHubTemplate::External(ExternalGitHubTemplate {
135                        org,
136                        repo,
137                        branch: &branch,
138                        path: None,
139                    }),
140                )
141                .await
142            }
143            4.. if segments[2] == "tree" => {
144                let org = &segments[0];
145
146                let repo = &segments[1];
147
148                let branch = &segments[3];
149
150                let path = segments.get(4..).map(|path| path.join("/"));
151
152                download_github_template(
153                    grafbase_path,
154                    GitHubTemplate::External(ExternalGitHubTemplate {
155                        org,
156                        repo,
157                        path,
158                        branch,
159                    }),
160                )
161                .await
162            }
163            _ => Err(BackendError::UnsupportedTemplateURL(repo_url.to_string())),
164        }
165    } else {
166        Err(BackendError::UnsupportedTemplateURL(repo_url.to_string()))
167    }
168}
169
170#[derive(Deserialize)]
171struct RepoInfo {
172    default_branch: String,
173}
174
175async fn get_default_branch(org: &str, repo: &str) -> Result<String, BackendError> {
176    let client = Client::new();
177
178    let response = client
179        .get(format!("https://api.github.com/repos/{org}/{repo}"))
180        // api.github.com requires a user agent header to be present
181        .header(header::USER_AGENT, USER_AGENT)
182        .send()
183        .await
184        .map_err(|_| BackendError::StartGetRepositoryInformation(format!("{org}/{repo}")))?;
185
186    if !response.status().is_success() {
187        return Err(BackendError::GetRepositoryInformation(format!("{org}/{repo}")));
188    }
189
190    let repo_info = response
191        .json::<RepoInfo>()
192        .await
193        .map_err(|_| BackendError::ReadRepositoryInformation(format!("{org}/{repo}")))?;
194
195    Ok(repo_info.default_branch)
196}
197
198fn to_project_path(name: Option<&str>) -> Result<PathBuf, BackendError> {
199    let current_dir = env::current_dir().map_err(|_| BackendError::ReadCurrentDirectory)?;
200    match name {
201        Some(name) => {
202            let project_path = current_dir.join(name);
203            if project_path.exists() {
204                Err(BackendError::ProjectDirectoryExists(project_path))
205            } else {
206                Ok(project_path)
207            }
208        }
209        None => Ok(current_dir),
210    }
211}
212
/// a template located in an arbitrary GitHub repository
#[derive(Clone)]
struct ExternalGitHubTemplate<'a> {
    /// the first path segment of the repository URL
    org: &'a str,
    /// the second path segment of the repository URL
    repo: &'a str,
    /// the path within the repository that contains the `grafbase` directory,
    /// `None` for the repository root
    path: Option<String>,
    /// the branch to download
    branch: &'a str,
}

/// a template referred to by name only
struct GrafbaseGithubTemplate<'a> {
    /// the name of the template, resolved below `templates/` in the `grafbase/grafbase` repository
    path: &'a str,
}

/// the supported sources of a template
enum GitHubTemplate<'a> {
    Grafbase(GrafbaseGithubTemplate<'a>),
    External(ExternalGitHubTemplate<'a>),
}

impl<'a> GitHubTemplate<'a> {
    /// converts either kind of template into the repository coordinates it is downloaded from
    ///
    /// a named template becomes `templates/{path}` on the `main` branch of `grafbase/grafbase`,
    /// an external template is returned as is
    pub fn into_external_github_template(self) -> ExternalGitHubTemplate<'a> {
        let name = match self {
            Self::External(template) => return template,
            Self::Grafbase(GrafbaseGithubTemplate { path }) => path,
        };

        ExternalGitHubTemplate {
            org: "grafbase",
            repo: "grafbase",
            path: Some(["templates", name].join("/")),
            branch: "main",
        }
    }
}
243
244async fn download_github_template(grafbase_path: PathBuf, template: GitHubTemplate<'_>) -> Result<(), BackendError> {
245    let ExternalGitHubTemplate {
246        org,
247        repo,
248        path,
249        branch,
250    } = template.into_external_github_template();
251
252    let org_and_repo = format!("{org}/{repo}");
253
254    let extraction_dir = PathBuf::from(format!("{repo}-{branch}"));
255
256    let mut template_path: PathBuf = PathBuf::from(&extraction_dir);
257
258    if let Some(path) = path {
259        template_path.push(path);
260    }
261
262    template_path.push("grafbase");
263
264    let extraction_result = stream_github_archive(grafbase_path, &org_and_repo, template_path, branch).await;
265
266    if extraction_dir.exists() {
267        tokio::fs::remove_dir_all(extraction_dir)
268            .await
269            .map_err(BackendError::CleanExtractedFiles)?;
270    }
271
272    extraction_result
273}
274
275async fn stream_github_archive<'a>(
276    grafbase_path: PathBuf,
277    org_and_repo: &'a str,
278    template_path: PathBuf,
279    branch: &'a str,
280) -> Result<(), BackendError> {
281    // not using the common environment since it's not initialized here
282    // if the OS does not have a cache path or it is not UTF-8, we don't cache the download
283    let cache_directory = dirs::cache_dir().and_then(|path| path.join("grafbase").to_str().map(ToOwned::to_owned));
284
285    let mut client_builder = ClientBuilder::new(Client::new());
286
287    if let Some(cache_directory) = cache_directory {
288        client_builder = client_builder.with(Cache(HttpCache {
289            mode: CacheMode::Default,
290            manager: CACacheManager { path: cache_directory },
291            options: None,
292        }));
293    }
294
295    let client = client_builder.build();
296
297    let tar_gz_response = client
298        .get(format!("https://codeload.github.com/{org_and_repo}/tar.gz/{branch}"))
299        .send()
300        .await
301        .map_err(|error| BackendError::StartDownloadRepoArchive(org_and_repo.to_owned(), error))?;
302
303    if !tar_gz_response.status().is_success() {
304        return Err(BackendError::DownloadRepoArchive(org_and_repo.to_owned()));
305    }
306
307    let tar_gz_stream = tar_gz_response
308        .bytes_stream()
309        .map(|result| result.map_err(|error| IoError::new(IoErrorKind::Other, error)));
310
311    let tar_gz_reader = StreamReader::new(tar_gz_stream);
312    let tar = GzipDecoder::new(tar_gz_reader);
313    let archive = Archive::new(tar.compat());
314
315    let mut entries = archive.entries().map_err(|_| BackendError::ReadArchiveEntries)?;
316
317    while let Some(entry) = entries.next().await {
318        let mut entry = entry.map_err(BackendError::ExtractArchiveEntry)?;
319
320        if entry
321            .path()
322            .ok()
323            .filter(|path| path.starts_with(&template_path))
324            .is_some()
325        {
326            entry.unpack_in(".").await.map_err(BackendError::ExtractArchiveEntry)?;
327        }
328    }
329
330    if !template_path.exists() {
331        return Err(BackendError::TemplateNotFound);
332    }
333
334    let project_folder = grafbase_path.parent().expect("must exist");
335
336    let named_project = !project_folder.exists();
337
338    if named_project {
339        tokio::fs::create_dir(project_folder)
340            .await
341            .map_err(BackendError::CreateProjectDirectory)?;
342    }
343
344    let rename_result = tokio::fs::rename(template_path, &grafbase_path)
345        .await
346        .map_err(BackendError::MoveExtractedFiles);
347
348    if rename_result.is_err() {
349        tokio::fs::remove_dir_all(project_folder)
350            .await
351            .map_err(BackendError::CleanExtractedFiles)?;
352    }
353
354    rename_result
355}
356
357/// resets the local data for the current project by removing the `.grafbase` directory
358///
359/// # Errors
360///
361/// - returns [`BackendError::ReadCurrentDirectory`] if the current directory cannot be read
362///
363/// - returns [`BackendError::DeleteDatabaseDirectory`] if the `.grafbase` directory cannot be deleted
364pub fn reset() -> Result<(), BackendError> {
365    let environment = Environment::get();
366
367    fs::remove_dir_all(&environment.database_directory_path).map_err(BackendError::DeleteDatabaseDirectory)?;
368
369    Ok(())
370}